[7.x][ML] ML Model Inference Ingest Processor (#49052) (#49257)

* [ML] ML Model Inference Ingest Processor (#49052)

* [ML][Inference] adds lazy model loader and inference (#47410)

This adds a few things:

- A model loader service that is accessible via transport calls. The service loads models and caches them; a model stays loaded until no processor references it (a toy sketch of this reference counting follows this list).
- A Model class and its first subclass, LocalModel, used to cache model information and run inference.
- A transport action and handler for requests to infer against a local model.
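For intuition only, a minimal reference-counting sketch of the caching behavior described above; this is plain Java, not the actual model loader service API, and every name in it is made up:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Toy cache: a model stays loaded while its reference count is above zero.
class ModelCacheSketch<M> {
    private final Map<String, M> cache = new ConcurrentHashMap<>();
    private final Map<String, AtomicInteger> refs = new ConcurrentHashMap<>();

    M acquire(String modelId, Function<String, M> loader) {
        refs.computeIfAbsent(modelId, id -> new AtomicInteger()).incrementAndGet();
        return cache.computeIfAbsent(modelId, loader); // load once, reuse afterwards
    }

    void release(String modelId) {
        AtomicInteger count = refs.get(modelId);
        if (count != null && count.decrementAndGet() <= 0) {
            refs.remove(modelId);
            cache.remove(modelId); // no processor references it any more
        }
    }
}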
Related Feature PRs:

* [ML][Inference] Adjust inference configuration option API (#47812)

* [ML][Inference] adds logistic_regression output aggregator (#48075)

* [ML][Inference] Adding read/del trained models (#47882)

* [ML][Inference] Adding inference ingest processor (#47859)

* [ML][Inference] fixing classification inference for ensemble (#48463)

* [ML][Inference] Adding model memory estimations (#48323)

* [ML][Inference] adding more options to inference processor (#48545)

* [ML][Inference] handle string values better in feature extraction (#48584)

* [ML][Inference] Adding _stats endpoint for inference (#48492)

* [ML][Inference] add inference processors and trained models to usage (#47869)

* [ML][Inference] add new flag for optionally including model definition (#48718)

* [ML][Inference] adding license checks (#49056)

* [ML][Inference] Adding memory and compute estimates to inference (#48955)

* fixing version of indexed docs for model inference
Benjamin Trent authored 2019-11-18 13:19:17 -05:00 (committed by GitHub)
parent 48f53efd9a
commit eefe7688ce
97 changed files with 7855 additions and 362 deletions

View File

@ -22,6 +22,7 @@ import org.elasticsearch.Version;
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
}
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@ -81,6 +86,8 @@ public class TrainedModelConfig implements ToXContentObject {
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;
TrainedModelConfig(String modelId,
String createdBy,
@ -90,7 +97,9 @@ public class TrainedModelConfig implements ToXContentObject {
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
@ -100,6 +109,8 @@ public class TrainedModelConfig implements ToXContentObject {
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
}
public String getModelId() {
@ -138,6 +149,18 @@ public class TrainedModelConfig implements ToXContentObject {
return input;
}
public ByteSizeValue getEstimatedHeapMemory() {
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
}
public Long getEstimatedHeapMemoryBytes() {
return estimatedHeapMemory;
}
public Long getEstimatedOperations() {
return estimatedOperations;
}
public static Builder builder() {
return new Builder();
}
@ -172,6 +195,12 @@ public class TrainedModelConfig implements ToXContentObject {
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
if (estimatedHeapMemory != null) {
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
}
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
builder.endObject();
return builder;
}
@ -194,6 +223,8 @@ public class TrainedModelConfig implements ToXContentObject {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}
@ -206,6 +237,8 @@ public class TrainedModelConfig implements ToXContentObject {
definition,
description,
tags,
estimatedHeapMemory,
estimatedOperations,
metadata,
input);
}
@ -222,6 +255,8 @@ public class TrainedModelConfig implements ToXContentObject {
private List<String> tags;
private TrainedModelDefinition definition;
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -277,6 +312,16 @@ public class TrainedModelConfig implements ToXContentObject {
return this;
}
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
@ -287,7 +332,9 @@ public class TrainedModelConfig implements ToXContentObject {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}
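For orientation (not part of the diff): a minimal sketch of building a client-side config with the new estimate fields. The builder and accessor names come from the hunks above; the package path, the class name ClientConfigSketch, the model id, and the values are assumptions.

import org.elasticsearch.client.ml.inference.TrainedModelConfig;

class ClientConfigSketch {
    static TrainedModelConfig withEstimates() {
        // setEstimatedHeapMemory stores raw bytes; getEstimatedHeapMemory()
        // exposes them as a ByteSizeValue ("1mb"), while
        // getEstimatedHeapMemoryBytes() returns the raw long.
        return TrainedModelConfig.builder()
            .setModelId("my-model")                // hypothetical id
            .setEstimatedHeapMemory(1024L * 1024L) // 1 MiB on heap
            .setEstimatedOperations(500L)          // operation count estimate
            .build();
    }
}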

View File

@ -64,7 +64,10 @@ public class TrainedModelConfigTests extends AbstractXContentTestCase<TrainedMod
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());
}
@Override

View File

@ -33,6 +33,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
public class IngestStats implements Writeable, ToXContentFragment {
@ -150,6 +151,21 @@ public class IngestStats implements Writeable, ToXContentFragment {
return processorStats;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats that = (IngestStats) o;
return Objects.equals(totalStats, that.totalStats)
&& Objects.equals(pipelineStats, that.pipelineStats)
&& Objects.equals(processorStats, that.processorStats);
}
@Override
public int hashCode() {
return Objects.hash(totalStats, pipelineStats, processorStats);
}
public static class Stats implements Writeable, ToXContentFragment {
private final long ingestCount;
@ -218,6 +234,22 @@ public class IngestStats implements Writeable, ToXContentFragment {
builder.field("failed", ingestFailedCount);
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.Stats that = (IngestStats.Stats) o;
return Objects.equals(ingestCount, that.ingestCount)
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
&& Objects.equals(ingestCurrent, that.ingestCurrent);
}
@Override
public int hashCode() {
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
}
}
/**
@ -270,6 +302,20 @@ public class IngestStats implements Writeable, ToXContentFragment {
public Stats getStats() {
return stats;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
return Objects.equals(pipelineId, that.pipelineId)
&& Objects.equals(stats, that.stats);
}
@Override
public int hashCode() {
return Objects.hash(pipelineId, stats);
}
}
/**
@ -297,5 +343,21 @@ public class IngestStats implements Writeable, ToXContentFragment {
public Stats getStats() {
return stats;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
return Objects.equals(name, that.name)
&& Objects.equals(type, that.type)
&& Objects.equals(stats, that.stats);
}
@Override
public int hashCode() {
return Objects.hash(name, type, stats);
}
}
}

View File

@ -88,6 +88,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
@ -109,6 +110,9 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@ -153,6 +157,19 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.P
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@ -371,6 +388,10 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
StopDataFrameAnalyticsAction.INSTANCE,
EvaluateDataFrameAction.INSTANCE,
EstimateMemoryUsageAction.INSTANCE,
InferModelAction.INSTANCE,
GetTrainedModelsAction.INSTANCE,
DeleteTrainedModelAction.INSTANCE,
GetTrainedModelsStatsAction.INSTANCE,
// security
ClearRealmCacheAction.INSTANCE,
ClearRolesCacheAction.INSTANCE,
@ -519,6 +540,16 @@ public class XPackClientPlugin extends Plugin implements ActionPlugin, NetworkPl
new NamedWriteableRegistry.Entry(OutputAggregator.class,
LogisticRegression.NAME.getPreferredName(),
LogisticRegression::new),
// ML - Inference Results
new NamedWriteableRegistry.Entry(InferenceResults.class,
ClassificationInferenceResults.NAME,
ClassificationInferenceResults::new),
new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new),
// ML - Inference Configuration
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),

View File

@ -29,10 +29,12 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
public static final String CREATED_BY = "created_by";
public static final String NODE_COUNT = "node_count";
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
public static final String INFERENCE_FIELD = "inference";
private final Map<String, Object> jobsUsage;
private final Map<String, Object> datafeedsUsage;
private final Map<String, Object> analyticsUsage;
private final Map<String, Object> inferenceUsage;
private final int nodeCount;
public MachineLearningFeatureSetUsage(boolean available,
@ -40,11 +42,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
Map<String, Object> jobsUsage,
Map<String, Object> datafeedsUsage,
Map<String, Object> analyticsUsage,
Map<String, Object> inferenceUsage,
int nodeCount) {
super(XPackField.MACHINE_LEARNING, available, enabled);
this.jobsUsage = Objects.requireNonNull(jobsUsage);
this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage);
this.analyticsUsage = Objects.requireNonNull(analyticsUsage);
this.inferenceUsage = Objects.requireNonNull(inferenceUsage);
this.nodeCount = nodeCount;
}
@ -57,12 +61,17 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
} else {
this.analyticsUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
this.inferenceUsage = in.readMap();
} else {
this.inferenceUsage = Collections.emptyMap();
}
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
this.nodeCount = in.readInt();
} else {
this.nodeCount = -1;
}
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
@ -72,10 +81,13 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
out.writeMap(analyticsUsage);
}
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
out.writeMap(inferenceUsage);
}
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
out.writeInt(nodeCount);
}
}
}
@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
@ -83,6 +95,7 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
builder.field(JOBS_FIELD, jobsUsage);
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
builder.field(INFERENCE_FIELD, inferenceUsage);
if (nodeCount >= 0) {
builder.field(NODE_COUNT, nodeCount);
}
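For orientation (not part of the diff): the new inferenceUsage field follows the usual wire-compatibility pattern, gated symmetrically on Version.V_7_6_0 for both read and write, with an empty-map default for older senders. A minimal standalone sketch of that shape, with the class name UsageShape assumed:

import java.io.IOException;
import java.util.Collections;
import java.util.Map;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;

class UsageShape implements Writeable {
    private final Map<String, Object> inferenceUsage;

    UsageShape(StreamInput in) throws IOException {
        // Only 7.6.0+ nodes send the map; older senders get a default.
        this.inferenceUsage = in.getVersion().onOrAfter(Version.V_7_6_0)
            ? in.readMap()
            : Collections.emptyMap();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Never send the new field to a node that cannot read it.
        if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
            out.writeMap(inferenceUsage);
        }
    }
}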

View File

@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
public class DeleteTrainedModelAction extends ActionType<AcknowledgedResponse> {
public static final DeleteTrainedModelAction INSTANCE = new DeleteTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/delete";
private DeleteTrainedModelAction() {
super(NAME, AcknowledgedResponse::new);
}
public static class Request extends AcknowledgedRequest<Request> implements ToXContentFragment {
private String id;
public Request(StreamInput in) throws IOException {
super(in);
id = in.readString();
}
public Request() {}
public Request(String id) {
this.id = ExceptionsHelper.requireNonNull(id, TrainedModelConfig.MODEL_ID);
}
public String getId() {
return id;
}
@Override
public ActionRequestValidationException validate() {
return null;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), id);
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DeleteTrainedModelAction.Request request = (DeleteTrainedModelAction.Request) o;
return Objects.equals(id, request.id);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(id);
}
@Override
public int hashCode() {
return Objects.hash(id);
}
}
}

View File

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
public static final GetTrainedModelsAction INSTANCE = new GetTrainedModelsAction();
public static final String NAME = "cluster:monitor/xpack/ml/inference/get";
private GetTrainedModelsAction() {
super(NAME, Response::new);
}
public static class Request extends AbstractGetResourcesRequest {
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
private final boolean includeModelDefinition;
public Request(String id, boolean includeModelDefinition) {
setResourceId(id);
setAllowNoResources(true);
this.includeModelDefinition = includeModelDefinition;
}
public Request(StreamInput in) throws IOException {
super(in);
this.includeModelDefinition = in.readBoolean();
}
@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}
public boolean isIncludeModelDefinition() {
return includeModelDefinition;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(includeModelDefinition);
}
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), includeModelDefinition);
}
@Override
public boolean equals(Object obj) {
if (obj == this) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Request other = (Request) obj;
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
}
}
public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {
public static final ParseField RESULTS_FIELD = new ParseField("trained_model_configs");
public Response(StreamInput in) throws IOException {
super(in);
}
public Response(QueryPage<TrainedModelConfig> trainedModels) {
super(trainedModels);
}
@Override
protected Reader<TrainedModelConfig> getReader() {
return TrainedModelConfig::new;
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private long totalCount;
private List<TrainedModelConfig> configs = Collections.emptyList();
private Builder() {
}
public Builder setTotalCount(long totalCount) {
this.totalCount = totalCount;
return this;
}
public Builder setModels(List<TrainedModelConfig> configs) {
this.configs = configs;
return this;
}
public Response build() {
return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD));
}
}
}
}
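For orientation (not part of the diff): a one-method sketch of requesting a model together with its full definition via the new include_model_definition flag (#48718); the helper class name and model id are assumptions.

import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;

class GetModelsSketch {
    static GetTrainedModelsAction.Request includeDefinition(String modelId) {
        // true -> include_model_definition: return the (potentially large)
        // model definition along with the configuration.
        return new GetTrainedModelsAction.Request(modelId, true);
    }
}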

View File

@ -0,0 +1,194 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStatsAction.Response> {
public static final GetTrainedModelsStatsAction INSTANCE = new GetTrainedModelsStatsAction();
public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
private GetTrainedModelsStatsAction() {
super(NAME, GetTrainedModelsStatsAction.Response::new);
}
public static class Request extends AbstractGetResourcesRequest {
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
public Request() {
setAllowNoResources(true);
}
public Request(String id) {
setResourceId(id);
setAllowNoResources(true);
}
public Request(StreamInput in) throws IOException {
super(in);
}
@Override
public String getResourceIdField() {
return TrainedModelConfig.MODEL_ID.getPreferredName();
}
}
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
public RequestBuilder(ElasticsearchClient client, GetTrainedModelsStatsAction action) {
super(client, action, new Request());
}
}
public static class Response extends AbstractGetResourcesResponse<Response.TrainedModelStats> {
public static class TrainedModelStats implements ToXContentObject, Writeable {
private final String modelId;
private final IngestStats ingestStats;
private final int pipelineCount;
private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
Collections.emptyList(),
Collections.emptyMap());
public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) {
this.modelId = Objects.requireNonNull(modelId);
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
if (pipelineCount < 0) {
throw new ElasticsearchException("[{}] must be greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
}
this.pipelineCount = pipelineCount;
}
public TrainedModelStats(StreamInput in) throws IOException {
modelId = in.readString();
ingestStats = new IngestStats(in);
pipelineCount = in.readVInt();
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID.getPreferredName(), modelId);
builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
if (pipelineCount > 0) {
// Ingest stats is a fragment
ingestStats.toXContent(builder, params);
}
builder.endObject();
return builder;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
ingestStats.writeTo(out);
out.writeVInt(pipelineCount);
}
@Override
public int hashCode() {
return Objects.hash(modelId, ingestStats, pipelineCount);
}
@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
TrainedModelStats other = (TrainedModelStats) obj;
return Objects.equals(this.modelId, other.modelId)
&& Objects.equals(this.ingestStats, other.ingestStats)
&& Objects.equals(this.pipelineCount, other.pipelineCount);
}
}
public static final ParseField RESULTS_FIELD = new ParseField("trained_model_stats");
public Response(StreamInput in) throws IOException {
super(in);
}
public Response(QueryPage<Response.TrainedModelStats> trainedModels) {
super(trainedModels);
}
@Override
protected Reader<Response.TrainedModelStats> getReader() {
return Response.TrainedModelStats::new;
}
public static class Builder {
private long totalModelCount;
private Set<String> expandedIds;
private Map<String, IngestStats> ingestStatsMap;
public Builder setTotalModelCount(long totalModelCount) {
this.totalModelCount = totalModelCount;
return this;
}
public Builder setExpandedIds(Set<String> expandedIds) {
this.expandedIds = expandedIds;
return this;
}
public Set<String> getExpandedIds() {
return this.expandedIds;
}
public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
this.ingestStatsMap = ingestStatsByModelId;
return this;
}
public Response build() {
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
expandedIds.forEach(id -> {
IngestStats ingestStats = ingestStatsMap.get(id);
trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ?
0 :
ingestStats.getPipelineStats().size()));
});
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
}
}
}
}
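For orientation (not part of the diff): a sketch of assembling a stats Response with the Builder above. When no ingest stats are recorded for a model, build() falls back to EMPTY_INGEST_STATS and a pipeline_count of 0. The class name and model ids here are assumptions.

import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;

import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;

class StatsResponseSketch {
    static GetTrainedModelsStatsAction.Response twoModelsNoPipelines() {
        Set<String> ids = new LinkedHashSet<>(Arrays.asList("model-a", "model-b"));
        // Empty stats map: each model gets EMPTY_INGEST_STATS, pipeline_count 0.
        return new GetTrainedModelsStatsAction.Response.Builder()
            .setTotalModelCount(2)
            .setExpandedIds(ids)
            .setIngestStatsByModelId(Collections.emptyMap())
            .build();
    }
}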

View File

@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class InferModelAction extends ActionType<InferModelAction.Response> {
public static final InferModelAction INSTANCE = new InferModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/infer";
private InferModelAction() {
super(NAME, Response::new);
}
public static class Request extends ActionRequest {
private final String modelId;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceConfig config;
public Request(String modelId) {
this(modelId, Collections.emptyList(), new RegressionConfig());
}
public Request(String modelId, List<Map<String, Object>> objectsToInfer, InferenceConfig inferenceConfig) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
}
public Request(String modelId, Map<String, Object> objectToInfer, InferenceConfig config) {
this(modelId,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
config);
}
public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.config = in.readNamedWriteable(InferenceConfig.class);
}
public String getModelId() {
return modelId;
}
public List<Map<String, Object>> getObjectsToInfer() {
return objectsToInfer;
}
public InferenceConfig getConfig() {
return config;
}
@Override
public ActionRequestValidationException validate() {
return null;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(modelId);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
out.writeNamedWriteable(config);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Request that = (InferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(config, that.config)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}
@Override
public int hashCode() {
return Objects.hash(modelId, objectsToInfer, config);
}
}
public static class Response extends ActionResponse {
private final List<InferenceResults> inferenceResults;
public Response(List<InferenceResults> inferenceResults) {
super();
this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults"));
}
public Response(StreamInput in) throws IOException {
super(in);
this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class));
}
public List<InferenceResults> getInferenceResults() {
return inferenceResults;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteableList(inferenceResults);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferModelAction.Response that = (InferModelAction.Response) o;
return Objects.equals(inferenceResults, that.inferenceResults);
}
@Override
public int hashCode() {
return Objects.hash(inferenceResults);
}
}
}
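For orientation (not part of the diff): a sketch of invoking the new transport action from a node client, using only the constructors and getters visible above; the helper class and the error handling are assumptions.

import java.util.Map;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;

class InferSketch {
    // Runs inference for a single document's fields against a loaded model.
    static void inferOne(Client client, String modelId, Map<String, Object> fields) {
        // Single-object Request constructor from the diff:
        // (modelId, objectToInfer, inferenceConfig)
        InferModelAction.Request request =
            new InferModelAction.Request(modelId, fields, new RegressionConfig());
        client.execute(InferModelAction.INSTANCE, request, ActionListener.wrap(
            response -> response.getInferenceResults()
                .forEach(result -> System.out.println(result.getName())),
            e -> { throw new RuntimeException(e); }));
    }
}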

View File

@ -8,7 +8,13 @@ package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
@ -110,6 +116,18 @@ public class MlInferenceNamedXContentProvider implements NamedXContentProvider {
LogisticRegression.NAME.getPreferredName(),
LogisticRegression::new));
// Inference Results
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
ClassificationInferenceResults.NAME,
ClassificationInferenceResults::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new));
// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));
return namedWriteables;
}
}

View File

@ -12,6 +12,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final String NAME = "trained_model_config";
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@ -66,6 +71,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
parser.declareObject(TrainedModelConfig.Builder::setInput,
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
INPUT);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
return parser;
}
@ -81,6 +88,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final long estimatedHeapMemory;
private final long estimatedOperations;
private final TrainedModelDefinition definition;
@ -92,7 +101,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@ -102,6 +113,15 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) {
throw new IllegalArgumentException(
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedHeapMemory = estimatedHeapMemory;
if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) {
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedOperations = estimatedOperations;
}
public TrainedModelConfig(StreamInput in) throws IOException {
@ -114,6 +134,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
metadata = in.readMap();
input = new TrainedModelInput(in);
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
}
public String getModelId() {
@ -157,6 +179,14 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return new Builder();
}
public long getEstimatedHeapMemory() {
return estimatedHeapMemory;
}
public long getEstimatedOperations() {
return estimatedOperations;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
@ -168,6 +198,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
out.writeCollection(tags, StreamOutput::writeString);
out.writeMap(metadata);
input.writeTo(out);
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
}
@Override
@ -180,7 +212,6 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
// We don't store the definition in the same document as the configuration
if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) {
builder.field(DEFINITION.getPreferredName(), definition);
@ -193,6 +224,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.field(INPUT.getPreferredName(), input);
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.endObject();
return builder;
}
@ -215,6 +251,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}
@ -228,6 +266,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
description,
tags,
metadata,
estimatedHeapMemory,
estimatedOperations,
input);
}
@ -242,6 +282,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
private Map<String, Object> metadata;
private TrainedModelInput input;
private TrainedModelDefinition definition;
private Long estimatedHeapMemory;
private Long estimatedOperations;
public Builder setModelId(String modelId) {
this.modelId = modelId;
@ -297,6 +339,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
return this;
}
public Builder setEstimatedHeapMemory(long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}
public Builder setEstimatedOperations(long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}
// TODO move to REST level instead of here in the builder
public void validate() {
// We require a definition to be available here even though it will be stored in a different doc
@ -327,6 +379,16 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATE_TIME.getPreferredName());
}
if (estimatedHeapMemory != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
}
if (estimatedOperations != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_OPERATIONS.getPreferredName());
}
}
public TrainedModelConfig build() {
@ -339,7 +401,9 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}

View File

@ -5,12 +5,16 @@
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
@ -19,6 +23,8 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConst
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@ -27,13 +33,18 @@ import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
public class TrainedModelDefinition implements ToXContentObject, Writeable {
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class);
public static final String NAME = "trained_model_definition";
public static final String HEAP_MEMORY_ESTIMATION = "heap_memory_estimation";
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
@ -106,6 +117,11 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
builder.humanReadableField(HEAP_MEMORY_ESTIMATION + "_bytes",
HEAP_MEMORY_ESTIMATION,
new ByteSizeValue(ramBytesUsed()));
}
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
assert modelId != null;
@ -123,6 +139,15 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return preProcessors;
}
private void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, config);
}
@Override
public String toString() {
return Strings.toString(this);
@ -143,6 +168,24 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable {
return Objects.hash(trainedModel, preProcessors, modelId);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(trainedModel);
size += RamUsageEstimator.sizeOfCollection(preProcessors);
return size;
}
@Override
public Collection<Accountable> getChildResources() {
List<Accountable> accountables = new ArrayList<>(preProcessors.size() + 2);
accountables.add(Accountables.namedAccountable("trained_model", trainedModel));
for (PreProcessor preProcessor : preProcessors) {
accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor));
}
return accountables;
}
public static class Builder {
private List<PreProcessor> preProcessors;
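For orientation (not part of the diff): the memory estimates reuse Lucene's Accountable pattern, shallow instance size plus the deep size of owned structures, with named child resources for debugging breakdowns. A minimal standalone sketch, with the class and field names assumed:

import java.util.Collection;
import java.util.Collections;

import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;

class EncodingShape implements Accountable {
    private static final long SHALLOW_SIZE =
        RamUsageEstimator.shallowSizeOfInstance(EncodingShape.class);
    private final String field = "category";

    @Override
    public long ramBytesUsed() {
        // Shallow object header plus the deep size of each owned structure.
        return SHALLOW_SIZE + RamUsageEstimator.sizeOf(field);
    }

    @Override
    public Collection<Accountable> getChildResources() {
        return Collections.singletonList(
            Accountables.namedAccountable("field_name", RamUsageEstimator.sizeOf(field)));
    }
}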

View File

@ -5,7 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -25,6 +27,8 @@ import java.util.Objects;
*/
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
public static final ParseField NAME = new ParseField("frequency_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
@ -143,4 +147,17 @@ public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyP
return Objects.hash(field, featureName, frequencyMap);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(frequencyMap);
return size;
}
@Override
public String toString() {
return Strings.toString(this);
}
}

View File

@ -5,7 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -23,6 +25,7 @@ import java.util.Objects;
*/
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
public static final ParseField NAME = new ParseField("one_hot_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField HOT_MAP = new ParseField("hot_map");
@ -127,4 +130,16 @@ public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyPars
return Objects.hash(field, hotMap);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOfMap(hotMap);
return size;
}
@Override
public String toString() {
return Strings.toString(this);
}
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
@ -14,7 +15,7 @@ import java.util.Map;
* Describes a pre-processor for a defined machine learning model
* This processor should take a set of fields and return the modified set of fields.
*/
public interface PreProcessor extends NamedXContentObject, NamedWriteable {
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
/**
* Process the given fields and their values and return the modified map.

View File

@ -5,7 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -25,6 +27,7 @@ import java.util.Objects;
*/
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
public static final ParseField NAME = new ParseField("target_mean_encoding");
public static final ParseField FIELD = new ParseField("field");
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
@ -158,4 +161,17 @@ public class TargetMeanEncoding implements LenientlyParsedPreProcessor, Strictly
return Objects.hash(field, featureName, meanMap, defaultValue);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(meanMap);
return size;
}
@Override
public String toString() {
return Strings.toString(this);
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
public class ClassificationInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "classification";
public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label");
public static final ParseField TOP_CLASSES = new ParseField("top_classes");
private final String classificationLabel;
private final List<TopClassEntry> topClasses;
public ClassificationInferenceResults(double value, String classificationLabel, List<TopClassEntry> topClasses) {
super(value);
this.classificationLabel = classificationLabel;
this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses);
}
public ClassificationInferenceResults(StreamInput in) throws IOException {
super(in);
this.classificationLabel = in.readOptionalString();
this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new));
}
public String getClassificationLabel() {
return classificationLabel;
}
public List<TopClassEntry> getTopClasses() {
return topClasses;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(classificationLabel);
out.writeCollection(topClasses);
}
@Override
XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
if (classificationLabel != null) {
builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel);
}
if (topClasses.isEmpty() == false) {
builder.field(TOP_CLASSES.getPreferredName(), topClasses);
}
return builder;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
ClassificationInferenceResults that = (ClassificationInferenceResults) object;
return Objects.equals(value(), that.value()) &&
Objects.equals(classificationLabel, that.classificationLabel) &&
Objects.equals(topClasses, that.topClasses);
}
@Override
public int hashCode() {
return Objects.hash(value(), classificationLabel, topClasses);
}
@Override
public String valueAsString() {
return classificationLabel == null ? super.valueAsString() : classificationLabel;
}
@Override
public void writeResult(IngestDocument document, String resultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
if (topClasses.isEmpty()) {
document.setFieldValue(resultField, valueAsString());
} else {
document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public String getName() {
return NAME;
}
public static class TopClassEntry implements ToXContentObject, Writeable {
public final ParseField CLASSIFICATION = new ParseField("classification");
public final ParseField PROBABILITY = new ParseField("probability");
private final String classification;
private final double probability;
public TopClassEntry(String classification, Double probability) {
this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION);
this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY);
}
public TopClassEntry(StreamInput in) throws IOException {
this.classification = in.readString();
this.probability = in.readDouble();
}
public String getClassification() {
return classification;
}
public double getProbability() {
return probability;
}
public Map<String, Object> asValueMap() {
Map<String, Object> map = new HashMap<>(2);
map.put(CLASSIFICATION.getPreferredName(), classification);
map.put(PROBABILITY.getPreferredName(), probability);
return map;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(classification);
out.writeDouble(probability);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(CLASSIFICATION.getPreferredName(), classification);
builder.field(PROBABILITY.getPreferredName(), probability);
builder.endObject();
return builder;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
TopClassEntry that = (TopClassEntry) object;
return Objects.equals(classification, that.classification) &&
Objects.equals(probability, that.probability);
}
@Override
public int hashCode() {
return Objects.hash(classification, probability);
}
}
}
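A rough usage sketch of the result class above. It assumes the three-argument constructor that Ensemble and Tree use further down in this commit, and an IngestDocument built directly from plain maps; the field path and labels are illustrative only:

import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;

class ClassificationResultSketch {
    static void example() {
        List<ClassificationInferenceResults.TopClassEntry> topClasses = Arrays.asList(
            new ClassificationInferenceResults.TopClassEntry("spam", 0.85),
            new ClassificationInferenceResults.TopClassEntry("ham", 0.15));
        // value 1.0 is the ordinal of the winning label "spam"
        ClassificationInferenceResults results =
            new ClassificationInferenceResults(1.0, "spam", topClasses);

        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
        results.writeResult(document, "ml.inference.predicted_value");
        // topClasses is non-empty, so the field holds a list of
        // {classification, probability} maps; with an empty topClasses list it
        // would hold the bare label from valueAsString().
    }
}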
@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
public interface InferenceResults extends NamedXContentObject, NamedWriteable {
void writeResult(IngestDocument document, String resultField);
}
@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import java.io.IOException;
import java.util.Objects;
public class RawInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "raw";
public RawInferenceResults(double value) {
super(value);
}
public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}
@Override
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
return builder;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
}
@Override
public int hashCode() {
return Objects.hash(value());
}
@Override
public void writeResult(IngestDocument document, String resultField) {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public String getName() {
return NAME;
}
}
@ -0,0 +1,68 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Objects;
public class RegressionInferenceResults extends SingleValueInferenceResults {
public static final String NAME = "regression";
public RegressionInferenceResults(double value) {
super(value);
}
public RegressionInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}
@Override
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
return builder;
}
@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RegressionInferenceResults that = (RegressionInferenceResults) object;
return Objects.equals(value(), that.value());
}
@Override
public int hashCode() {
return Objects.hash(value());
}
@Override
public void writeResult(IngestDocument document, String resultField) {
ExceptionsHelper.requireNonNull(document, "document");
ExceptionsHelper.requireNonNull(resultField, "resultField");
document.setFieldValue(resultField, value());
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public String getName() {
return NAME;
}
}
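A minimal sketch of the regression variant's writeResult, under the same IngestDocument assumption as the classification sketch above; dotted result paths become nested objects in the document source:

import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;

import java.util.HashMap;

class RegressionResultSketch {
    static void example() {
        RegressionInferenceResults results = new RegressionInferenceResults(128.4);
        IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
        // Regression always writes the raw double, never a label.
        results.writeResult(document, "ml.inference.predicted_value");
        double value = document.getFieldValue("ml.inference.predicted_value", Double.class);
        assert value == 128.4;
    }
}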
@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
public abstract class SingleValueInferenceResults implements InferenceResults {
public final ParseField VALUE = new ParseField("value");
private final double value;
SingleValueInferenceResults(StreamInput in) throws IOException {
value = in.readDouble();
}
SingleValueInferenceResults(double value) {
this.value = value;
}
public Double value() {
return value;
}
public String valueAsString() {
return String.valueOf(value);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeDouble(value);
}
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(VALUE.getPreferredName(), value);
innerToXContent(builder, params);
builder.endObject();
return builder;
}
abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException;
}
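To show the wire-facing shape, a sketch serializing a single-value result with the stock XContent helpers; the choice of XContentFactory/Strings here is an assumption, not part of this commit:

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;

import java.io.IOException;

class SingleValueXContentSketch {
    static String example() throws IOException {
        XContentBuilder builder = XContentFactory.jsonBuilder();
        new RegressionInferenceResults(1.5).toXContent(builder, ToXContent.EMPTY_PARAMS);
        // Subclasses only contribute extra fields via innerToXContent;
        // regression adds none, so this renders as {"value":1.5}
        return Strings.toString(builder);
    }
}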
@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
public class ClassificationConfig implements InferenceConfig {
public static final String NAME = "classification";
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static final ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);
private final int numTopClasses;
public static ClassificationConfig fromMap(Map<String, Object> map) {
Map<String, Object> options = new HashMap<>(map);
Integer numTopClasses = (Integer)options.remove(NUM_TOP_CLASSES.getPreferredName());
if (options.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", options.keySet());
}
return new ClassificationConfig(numTopClasses);
}
public ClassificationConfig(Integer numTopClasses) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
}
public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
}
public int getNumTopClasses() {
return numTopClasses;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeInt(numTopClasses);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses);
}
@Override
public int hashCode() {
return Objects.hash(numTopClasses);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (numTopClasses != 0) {
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
}
builder.endObject();
return builder;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public String getName() {
return NAME;
}
@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.CLASSIFICATION.equals(targetType);
}
@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}
}
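A short sketch of the fromMap path, presumably the hook for processor-supplied options; it assumes the sketch lives alongside ClassificationConfig in the trainedmodel package:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

class ClassificationConfigSketch {
    static void example() {
        Map<String, Object> options = new HashMap<>();
        options.put("num_top_classes", 3);
        ClassificationConfig config = ClassificationConfig.fromMap(options);
        assert config.getNumTopClasses() == 3;

        try {
            // Any key other than num_top_classes is rejected:
            // "Unrecognized fields [some_option]."
            ClassificationConfig.fromMap(Collections.singletonMap("some_option", 1));
        } catch (Exception e) {
            // bad request
        }
    }
}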
@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
public interface InferenceConfig extends NamedXContentObject, NamedWriteable {
boolean isTargetTypeSupported(TargetType targetType);
/**
* All nodes in the cluster must be at least this version
*/
Version getMinimalSupportedVersion();
}
@ -0,0 +1,89 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public final class InferenceHelpers {
private InferenceHelpers() { }
public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
List<String> classificationLabels,
int numToInclude) {
if (numToInclude == 0) {
return Collections.emptyList();
}
int[] sortedIndices = IntStream.range(0, probabilities.size())
.boxed()
.sorted(Comparator.comparing(probabilities::get).reversed())
.mapToInt(i -> i)
.toArray();
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
throw ExceptionsHelper
.serverError(
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
null,
probabilities.size(),
classificationLabels.size());
}
List<String> labels = classificationLabels == null ?
// If we don't have the labels we should return the top classification values anyways, they will just be numeric
IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) :
classificationLabels;
int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size());
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
for(int i = 0; i < count; i++) {
int idx = sortedIndices[i];
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
}
return topClassEntries;
}
public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
assert inferenceValue == Math.rint(inferenceValue);
if (classificationLabels == null) {
return String.valueOf(inferenceValue);
}
int label = Double.valueOf(inferenceValue).intValue();
if (label < 0 || label >= classificationLabels.size()) {
throw ExceptionsHelper.serverError(
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
null,
label,
classificationLabels);
}
return classificationLabels.get(label);
}
public static Double toDouble(Object value) {
if (value instanceof Number) {
return ((Number)value).doubleValue();
}
if (value instanceof String) {
try {
return Double.valueOf((String)value);
} catch (NumberFormatException nfe) {
assert false : "value is not properly formatted double [" + value + "]";
return null;
}
}
return null;
}
}
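A sketch of the three helpers, assuming it sits in the same trainedmodel package as InferenceHelpers:

import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;

import java.util.Arrays;
import java.util.List;

class InferenceHelpersSketch {
    static void example() {
        List<Double> probabilities = Arrays.asList(0.1, 0.7, 0.2);
        List<String> labels = Arrays.asList("rejected", "approved", "review");

        // Entries come back sorted by descending probability; a negative
        // numToInclude would return all of them.
        List<ClassificationInferenceResults.TopClassEntry> top =
            InferenceHelpers.topClasses(probabilities, labels, 2);
        // -> approved (0.7), review (0.2)

        // Whole-valued inferences index into the ordinal-encoded labels.
        String label = InferenceHelpers.classificationLabel(1.0, labels); // "approved"

        // Lenient conversion used for feature extraction.
        Double parsed = InferenceHelpers.toDouble("3.5"); // 3.5; non-numeric -> null
    }
}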
@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
/**
* Used by ensemble to pass into sub-models.
*/
public class NullInferenceConfig implements InferenceConfig {
public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();
private NullInferenceConfig() { }
@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return true;
}
@Override
public Version getMinimalSupportedVersion() {
return Version.CURRENT;
}
@Override
public String getWriteableName() {
return "null";
}
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
public String getName() {
return "null";
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return builder;
}
}
@ -0,0 +1,80 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
public class RegressionConfig implements InferenceConfig {
public static final String NAME = "regression";
private static final Version MIN_SUPPORTED_VERSION = Version.V_7_6_0;
public static RegressionConfig fromMap(Map<String, Object> map) {
if (map.isEmpty() == false) {
throw ExceptionsHelper.badRequestException("Unrecognized fields {}.", map.keySet());
}
return new RegressionConfig();
}
public RegressionConfig() {
}
public RegressionConfig(StreamInput in) {
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
}
@Override
public String getName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
return true;
}
@Override
public int hashCode() {
return Objects.hash(NAME);
}
@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}
@Override
public Version getMinimalSupportedVersion() {
return MIN_SUPPORTED_VERSION;
}
}
@ -5,14 +5,16 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
import java.util.Map;
public interface TrainedModel extends NamedXContentObject, NamedWriteable {
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
/**
* @return List of featureNames expected by the model. In the order that they are expected
@ -23,16 +25,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
* Infer against the provided fields
*
* @param fields The fields and their values to infer against
* @param config The configuration options for inference
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous.
*/
double infer(Map<String, Object> fields);
/**
* @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles
* @return The predicted value.
*/
double infer(List<Double> fields);
InferenceResults infer(Map<String, Object> fields, InferenceConfig config);
/**
* @return {@link TargetType} for the model.
@ -40,26 +37,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
TargetType targetType();
/**
* This gathers the probabilities for each potential classification value.
*
* The probabilities are indexed by classification ordinal label encoding.
* The length of this list is equal to the number of classification labels.
*
* This only should return if the implementation model is inferring classification values and not regression
* @param fields The fields and their values to infer against
* @return The probabilities of each classification value
*/
List<Double> classificationProbability(Map<String, Object> fields);
/**
* @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles
* @return The probabilities of each classification value
*/
List<Double> classificationProbability(List<Double> fields);
/**
* The ordinal encoded list of the classification labels.
* @return Oridinal encoded list of classification labels.
* @return Ordinal encoded list of classification labels.
*/
@Nullable
List<String> classificationLabels();
@ -72,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
* @throws org.elasticsearch.ElasticsearchException if validations fail
*/
void validate();
/**
* @return The estimated number of operations required at inference time
*/
long estimatedNumOperations();
}
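The net effect of this interface change is that the old infer/classificationProbability pair collapses into one config-driven entry point. A hedged caller sketch, assuming it sits in the trainedmodel package; the helper name and the fixed num_top_classes are illustrative only:

import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;

import java.util.Map;

class TrainedModelCallerSketch {
    // Works the same for Tree and Ensemble; the config decides whether the
    // result carries classification extras (label, top classes) or not.
    static Double predict(TrainedModel model, Map<String, Object> fields) {
        InferenceResults results = model.infer(fields, new ClassificationConfig(2));
        return results instanceof SingleValueInferenceResults
            ? ((SingleValueInferenceResults) results).value()
            : null;
    }
}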
@ -5,6 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
@ -12,7 +15,16 @@ import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@ -20,14 +32,20 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class);
// TODO should we have regression/classification sub-classes that accept the builder?
public static final ParseField NAME = new ParseField("ensemble");
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
@ -106,14 +124,18 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
}
@Override
public double infer(Map<String, Object> fields) {
List<Double> processedInferences = inferAndProcess(fields);
return outputAggregator.aggregate(processedInferences);
}
@Override
public double infer(List<Double> fields) {
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = this.models.stream().map(model -> {
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
assert results instanceof SingleValueInferenceResults;
return ((SingleValueInferenceResults)results).value();
}).collect(Collectors.toList());
List<Double> processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, config);
}
@Override
@ -121,18 +143,27 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return targetType;
}
@Override
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
// A NullInferenceConfig signals that the caller just wants the raw aggregated value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences));
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationConfig.getNumTopClasses());
double value = outputAggregator.aggregate(processedInferences);
return new ClassificationInferenceResults(value,
classificationLabel(value, classificationLabels),
topClasses);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model");
}
return inferAndProcess(fields);
}
@Override
public List<Double> classificationProbability(List<Double> fields) {
throw new UnsupportedOperationException("Ensemble requires map containing field names and values");
}
@Override
@ -140,11 +171,6 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
return classificationLabels;
}
private List<Double> inferAndProcess(Map<String, Object> fields) {
List<Double> modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList());
return outputAggregator.processValues(modelInferences);
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -204,6 +230,13 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
@Override
public void validate() {
if (outputAggregator.compatibleWith(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"aggregate_output [{}] is not compatible with target_type [{}]",
outputAggregator.getName(),
this.targetType
);
}
if (outputAggregator.expectedValueSize() != null &&
outputAggregator.expectedValueSize() != models.size()) {
throw ExceptionsHelper.badRequestException(
@ -219,10 +252,38 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai
this.models.forEach(TrainedModel::validate);
}
@Override
public long estimatedNumOperations() {
OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average();
assert avg.isPresent() : "unexpected null when calculating number of operations";
// Average operations for each model and the operations required for processing and aggregating with the outputAggregator
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
}
public static Builder builder() {
return new Builder();
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOfCollection(featureNames);
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOfCollection(models);
size += outputAggregator.ramBytesUsed();
return size;
}
@Override
public Collection<Accountable> getChildResources() {
List<Accountable> accountables = new ArrayList<>(models.size() + 1);
for (TrainedModel model : models) {
accountables.add(Accountables.namedAccountable(model.getName(), model));
}
accountables.add(Accountables.namedAccountable(outputAggregator.getName(), outputAggregator));
return Collections.unmodifiableCollection(accountables);
}
public static class Builder {
private List<String> featureNames;
private List<TrainedModel> trainedModels;
@ -6,16 +6,17 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;
@ -24,6 +25,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.sigmoid
public class LogisticRegression implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LogisticRegression.class);
public static final ParseField NAME = new ParseField("logistic_regression");
public static final ParseField WEIGHTS = new ParseField("weights");
@ -48,19 +50,23 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
return LENIENT_PARSER.apply(parser, null);
}
private final List<Double> weights;
private final double[] weights;
LogisticRegression() {
this((List<Double>) null);
}
public LogisticRegression(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
private LogisticRegression(List<Double> weights) {
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
}
public LogisticRegression(double[] weights) {
this.weights = weights;
}
public LogisticRegression(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
this.weights = in.readDoubleArray();
} else {
this.weights = null;
}
@ -68,18 +74,18 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
@Override
public Integer expectedValueSize() {
return this.weights == null ? null : this.weights.size();
return this.weights == null ? null : this.weights.length;
}
@Override
public List<Double> processValues(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.size()) {
if (weights != null && values.size() != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
double summation = weights == null ?
values.stream().mapToDouble(Double::valueOf).sum() :
IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).sum();
IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).sum();
double probOfClassOne = sigmoid(summation);
assert 0.0 <= probOfClassOne && probOfClassOne <= 1.0;
return Arrays.asList(1.0 - probOfClassOne, probOfClassOne);
@ -108,6 +114,11 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
return NAME.getPreferredName();
}
@Override
public boolean compatibleWith(TargetType targetType) {
return true;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -117,7 +128,7 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
out.writeDoubleArray(weights);
}
}
@ -136,12 +147,17 @@ public class LogisticRegression implements StrictlyParsedOutputAggregator, Lenie
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
LogisticRegression that = (LogisticRegression) o;
return Objects.equals(weights, that.weights);
return Arrays.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
return Arrays.hashCode(weights);
}
@Override
public long ramBytesUsed() {
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
return SHALLOW_SIZE + weightSize;
}
}
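A worked sketch of the aggregator above: processValues turns the weighted sum of the sub-model outputs into [P(class 0), P(class 1)], and aggregate then reduces that pair to the final prediction:

import java.util.Arrays;
import java.util.List;

class LogisticRegressionSketch {
    static void example() {
        LogisticRegression aggregator = new LogisticRegression(new double[] {0.5, 0.5});
        // 0.5 * 1.0 + 0.5 * 3.0 = 2.0, and sigmoid(2.0) ~= 0.88,
        // so processed ~= [0.12, 0.88]
        List<Double> processed = aggregator.processValues(Arrays.asList(1.0, 3.0));
        double prediction = aggregator.aggregate(processed);
    }
}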
@ -5,12 +5,14 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.Accountable;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import java.util.List;
public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
public interface OutputAggregator extends NamedXContentObject, NamedWriteable, Accountable {
/**
* @return The expected size of the values array when aggregating. `null` implies there is no expected size.
@ -44,4 +46,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable {
* @return The name of the output aggregator
*/
String getName();
boolean compatibleWith(TargetType targetType);
}
@ -6,15 +6,18 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
@ -23,6 +26,7 @@ import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax
public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
public static final ParseField NAME = new ParseField("weighted_mode");
public static final ParseField WEIGHTS = new ParseField("weights");
@ -47,19 +51,23 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
return LENIENT_PARSER.apply(parser, null);
}
private final List<Double> weights;
private final double[] weights;
WeightedMode() {
this.weights = null;
this((List<Double>) null);
}
public WeightedMode(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
private WeightedMode(List<Double> weights) {
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
}
public WeightedMode(double[] weights) {
this.weights = weights;
}
public WeightedMode(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
this.weights = in.readDoubleArray();
} else {
this.weights = null;
}
@ -67,13 +75,13 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
@Override
public Integer expectedValueSize() {
return this.weights == null ? null : this.weights.size();
return this.weights == null ? null : this.weights.length;
}
@Override
public List<Double> processValues(List<Double> values) {
Objects.requireNonNull(values, "values must not be null");
if (weights != null && values.size() != weights.size()) {
if (weights != null && values.size() != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
List<Integer> freqArray = new ArrayList<>();
@ -93,7 +101,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
}
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
for (int i = 0; i < freqArray.size(); i++) {
Double weight = weights == null ? 1.0 : weights.get(i);
Double weight = weights == null ? 1.0 : weights[i];
Integer value = freqArray.get(i);
Double frequency = frequencies.get(value) == Double.NEGATIVE_INFINITY ? weight : frequencies.get(value) + weight;
frequencies.set(value, frequency);
@ -123,6 +131,11 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
return NAME.getPreferredName();
}
@Override
public boolean compatibleWith(TargetType targetType) {
return true;
}
@Override
public String getWriteableName() {
return NAME.getPreferredName();
@ -132,7 +145,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
out.writeDoubleArray(weights);
}
}
@ -151,11 +164,17 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedMode that = (WeightedMode) o;
return Objects.equals(weights, that.weights);
return Arrays.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
return Arrays.hashCode(weights);
}
@Override
public long ramBytesUsed() {
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
return SHALLOW_SIZE + weightSize;
}
}
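Per the body above, each incoming value is treated as an integer class vote and the matching weight accumulates per class; a small sketch follows (the exact normalization of the processed scores comes from the softMax imported at the top of the file):

import java.util.Arrays;
import java.util.List;

class WeightedModeSketch {
    static void example() {
        // Third voter carries double weight.
        WeightedMode mode = new WeightedMode(new double[] {1.0, 1.0, 2.0});
        // Votes: class 0, class 1, class 1 -> class 1 accumulates 3.0 vs 1.0,
        // so the processed per-class scores favor class 1.
        List<Double> processed = mode.processValues(Arrays.asList(0.0, 1.0, 1.0));
        double prediction = mode.aggregate(processed);
    }
}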
@ -6,15 +6,17 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Collections;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@ -23,6 +25,7 @@ import java.util.stream.IntStream;
public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedSum.class);
public static final ParseField NAME = new ParseField("weighted_sum");
public static final ParseField WEIGHTS = new ParseField("weights");
@ -47,19 +50,23 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
return LENIENT_PARSER.apply(parser, null);
}
private final List<Double> weights;
private final double[] weights;
WeightedSum() {
this.weights = null;
this((List<Double>) null);
}
public WeightedSum(List<Double> weights) {
this.weights = weights == null ? null : Collections.unmodifiableList(weights);
private WeightedSum(List<Double> weights) {
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
}
public WeightedSum(double[] weights) {
this.weights = weights;
}
public WeightedSum(StreamInput in) throws IOException {
if (in.readBoolean()) {
this.weights = Collections.unmodifiableList(in.readList(StreamInput::readDouble));
this.weights = in.readDoubleArray();
} else {
this.weights = null;
}
@ -71,10 +78,10 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
if (weights == null) {
return values;
}
if (values.size() != weights.size()) {
if (values.size() != weights.length) {
throw new IllegalArgumentException("values must be the same length as weights.");
}
return IntStream.range(0, weights.size()).mapToDouble(i -> values.get(i) * weights.get(i)).boxed().collect(Collectors.toList());
return IntStream.range(0, weights.length).mapToDouble(i -> values.get(i) * weights[i]).boxed().collect(Collectors.toList());
}
@Override
@ -104,7 +111,7 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(weights != null);
if (weights != null) {
out.writeCollection(weights, StreamOutput::writeDouble);
out.writeDoubleArray(weights);
}
}
@ -123,16 +130,27 @@ public class WeightedSum implements StrictlyParsedOutputAggregator, LenientlyPar
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightedSum that = (WeightedSum) o;
return Objects.equals(weights, that.weights);
return Arrays.equals(weights, that.weights);
}
@Override
public int hashCode() {
return Objects.hash(weights);
return Arrays.hashCode(weights);
}
@Override
public Integer expectedValueSize() {
return weights == null ? null : this.weights.size();
return weights == null ? null : this.weights.length;
}
@Override
public boolean compatibleWith(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}
@Override
public long ramBytesUsed() {
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
return SHALLOW_SIZE + weightSize;
}
}
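WeightedSum is the one aggregator in this commit that rejects classification targets; processValues is a plain element-wise scaling, shown in a short sketch:

import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.util.Arrays;
import java.util.List;

class WeightedSumSketch {
    static void example() {
        WeightedSum sum = new WeightedSum(new double[] {0.4, 0.6});
        // [10.0 * 0.4, 20.0 * 0.6] = [4.0, 12.0]
        List<Double> processed = sum.processValues(Arrays.asList(10.0, 20.0));
        boolean regressionOnly = sum.compatibleWith(TargetType.REGRESSION); // true
    }
}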
@ -5,6 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -13,7 +16,15 @@ import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -22,6 +33,7 @@ import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
@ -31,8 +43,11 @@ import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class);
// TODO should we have regression/classification sub-classes that accept the builder?
public static final ParseField NAME = new ParseField("tree");
@ -72,7 +87,10 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE));
if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
throw new IllegalArgumentException("[tree_structure] must not be empty");
}
this.nodes = Collections.unmodifiableList(nodes);
this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue());
@ -105,20 +123,42 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
@Override
public double infer(Map<String, Object> fields) {
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null
).collect(Collectors.toList());
return infer(features);
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> features = featureNames.stream().map(f -> InferenceHelpers.toDouble(fields.get(f))).collect(Collectors.toList());
return infer(features, config);
}
@Override
public double infer(List<Double> features) {
private InferenceResults infer(List<Double> features, InferenceConfig config) {
TreeNode node = nodes.get(0);
while(node.isLeaf() == false) {
node = nodes.get(node.compare(features));
}
return node.getLeafValue();
return buildResult(node.getLeafValue(), config);
}
private InferenceResults buildResult(Double value, InferenceConfig config) {
// A NullInferenceConfig signals that the caller just wants the raw leaf value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value);
}
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
classificationProbability(value),
classificationLabels,
classificationConfig.getNumTopClasses());
return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses);
case REGRESSION:
return new RegressionInferenceResults(value);
default:
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
}
}
/**
@ -142,34 +182,15 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
return targetType;
}
@Override
public List<Double> classificationProbability(Map<String, Object> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
List<Double> features = featureNames.stream().map(f ->
fields.get(f) instanceof Number ? ((Number)fields.get(f)).doubleValue() : null)
.collect(Collectors.toList());
return classificationProbability(features);
}
@Override
public List<Double> classificationProbability(List<Double> fields) {
if ((targetType == TargetType.CLASSIFICATION) == false) {
throw new UnsupportedOperationException(
"Cannot determine classification probability with target_type [" + targetType.toString() + "]");
}
double label = infer(fields);
private List<Double> classificationProbability(double inferenceValue) {
// If we are classification, we should assume that the inference return value is whole.
assert label == Math.rint(label);
assert inferenceValue == Math.rint(inferenceValue);
double maxCategory = this.highestOrderCategory.get();
// If we are classification, we should assume that the largest leaf value is whole.
assert maxCategory == Math.rint(maxCategory);
List<Double> list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0));
// TODO, eventually have TreeNodes contain confidence levels
list.set(Double.valueOf(label).intValue(), 1.0);
list.set(Double.valueOf(inferenceValue).intValue(), 1.0);
return list;
}
@ -239,6 +260,12 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
detectCycle();
}
@Override
public long estimatedNumOperations() {
// Grabbing the features from the doc + the depth of the tree
return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size();
}
private void checkTargetType() {
if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) {
throw ExceptionsHelper.badRequestException(
@ -247,9 +274,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
private void detectCycle() {
if (nodes.isEmpty()) {
return;
}
Set<Integer> visited = new HashSet<>(nodes.size());
Queue<Integer> toVisit = new ArrayDeque<>(nodes.size());
toVisit.add(0);
@ -270,10 +294,6 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
}
private void detectMissingNodes() {
if (nodes.isEmpty()) {
return;
}
List<Integer> missingNodes = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
TreeNode currentNode = nodes.get(i);
@ -302,6 +322,24 @@ public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedM
null;
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
size += RamUsageEstimator.sizeOfCollection(featureNames);
size += RamUsageEstimator.sizeOfCollection(nodes);
return size;
}
@Override
public Collection<Accountable> getChildResources() {
List<Accountable> accountables = new ArrayList<>(nodes.size());
for (TreeNode node : nodes) {
accountables.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), node));
}
return Collections.unmodifiableCollection(accountables);
}
public static class Builder {
private List<String> featureNames;
private ArrayList<TreeNode.Builder> nodes;
@ -5,6 +5,9 @@
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
@ -15,14 +18,15 @@ import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.job.config.Operator;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.List;
import java.util.Objects;
public class TreeNode implements ToXContentObject, Writeable {
public class TreeNode implements ToXContentObject, Writeable, Accountable {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TreeNode.class);
public static final String NAME = "tree_node";
public static final ParseField DECISION_TYPE = new ParseField("decision_type");
@ -63,31 +67,31 @@ public class TreeNode implements ToXContentObject, Writeable {
}
private final Operator operator;
private final Double threshold;
private final Integer splitFeature;
private final double threshold;
private final int splitFeature;
private final int nodeIndex;
private final Double splitGain;
private final Double leafValue;
private final double splitGain;
private final double leafValue;
private final boolean defaultLeft;
private final int leftChild;
private final int rightChild;
TreeNode(Operator operator,
Double threshold,
Integer splitFeature,
Integer nodeIndex,
Double splitGain,
Double leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild) {
private TreeNode(Operator operator,
Double threshold,
Integer splitFeature,
int nodeIndex,
Double splitGain,
Double leafValue,
Boolean defaultLeft,
Integer leftChild,
Integer rightChild) {
this.operator = operator == null ? Operator.LTE : operator;
this.threshold = threshold;
this.splitFeature = splitFeature;
this.nodeIndex = ExceptionsHelper.requireNonNull(nodeIndex, NODE_INDEX.getPreferredName());
this.splitGain = splitGain;
this.leafValue = leafValue;
this.threshold = threshold == null ? Double.NaN : threshold;
this.splitFeature = splitFeature == null ? -1 : splitFeature;
this.nodeIndex = nodeIndex;
this.splitGain = splitGain == null ? Double.NaN : splitGain;
this.leafValue = leafValue == null ? Double.NaN : leafValue;
this.defaultLeft = defaultLeft == null ? false : defaultLeft;
this.leftChild = leftChild == null ? -1 : leftChild;
this.rightChild = rightChild == null ? -1 : rightChild;
@ -95,11 +99,11 @@ public class TreeNode implements ToXContentObject, Writeable {
public TreeNode(StreamInput in) throws IOException {
operator = Operator.readFromStream(in);
threshold = in.readOptionalDouble();
splitFeature = in.readOptionalInt();
splitGain = in.readOptionalDouble();
nodeIndex = in.readInt();
leafValue = in.readOptionalDouble();
threshold = in.readDouble();
splitFeature = in.readInt();
splitGain = in.readDouble();
nodeIndex = in.readVInt();
leafValue = in.readDouble();
defaultLeft = in.readBoolean();
leftChild = in.readInt();
rightChild = in.readInt();
@ -110,23 +114,23 @@ public class TreeNode implements ToXContentObject, Writeable {
return operator;
}
public Double getThreshold() {
public double getThreshold() {
return threshold;
}
public Integer getSplitFeature() {
public int getSplitFeature() {
return splitFeature;
}
public Integer getNodeIndex() {
public int getNodeIndex() {
return nodeIndex;
}
public Double getSplitGain() {
public double getSplitGain() {
return splitGain;
}
public Double getLeafValue() {
public double getLeafValue() {
return leafValue;
}
@ -164,11 +168,11 @@ public class TreeNode implements ToXContentObject, Writeable {
@Override
public void writeTo(StreamOutput out) throws IOException {
operator.writeTo(out);
out.writeOptionalDouble(threshold);
out.writeOptionalInt(splitFeature);
out.writeOptionalDouble(splitGain);
out.writeInt(nodeIndex);
out.writeOptionalDouble(leafValue);
out.writeDouble(threshold);
out.writeInt(splitFeature);
out.writeDouble(splitGain);
out.writeVInt(nodeIndex);
out.writeDouble(leafValue);
out.writeBoolean(defaultLeft);
out.writeInt(leftChild);
out.writeInt(rightChild);
@ -177,12 +181,14 @@ public class TreeNode implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
addOptionalField(builder, DECISION_TYPE, operator);
addOptionalField(builder, THRESHOLD, threshold);
addOptionalField(builder, SPLIT_FEATURE, splitFeature);
addOptionalField(builder, SPLIT_GAIN, splitGain);
builder.field(DECISION_TYPE.getPreferredName(), operator);
addOptionalDouble(builder, THRESHOLD, threshold);
if (splitFeature > -1) {
builder.field(SPLIT_FEATURE.getPreferredName(), splitFeature);
}
addOptionalDouble(builder, SPLIT_GAIN, splitGain);
builder.field(NODE_INDEX.getPreferredName(), nodeIndex);
addOptionalField(builder, LEAF_VALUE, leafValue);
addOptionalDouble(builder, LEAF_VALUE, leafValue);
builder.field(DEFAULT_LEFT.getPreferredName(), defaultLeft);
if (leftChild >= 0) {
builder.field(LEFT_CHILD.getPreferredName(), leftChild);
@ -194,8 +200,8 @@ public class TreeNode implements ToXContentObject, Writeable {
return builder;
}
private void addOptionalField(XContentBuilder builder, ParseField field, Object value) throws IOException {
if (value != null) {
private void addOptionalDouble(XContentBuilder builder, ParseField field, double value) throws IOException {
if (Numbers.isValidDouble(value)) {
builder.field(field.getPreferredName(), value);
}
}
@ -237,7 +243,12 @@ public class TreeNode implements ToXContentObject, Writeable {
public static Builder builder(int nodeIndex) {
return new Builder(nodeIndex);
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE;
}
public static class Builder {
private Operator operator;
private Double threshold;
@ -5,6 +5,8 @@
*/
package org.elasticsearch.xpack.core.ml.inference.utils;
import org.elasticsearch.common.Numbers;
import java.util.List;
import java.util.stream.Collectors;
@ -22,24 +24,24 @@ public final class Statistics {
*/
public static List<Double> softMax(List<Double> values) {
Double expSum = 0.0;
Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null);
Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null);
if (max == null) {
throw new IllegalArgumentException("no valid values present");
}
List<Double> exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max)
List<Double> exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY)
.collect(Collectors.toList());
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i)) == false) {
if (isValid(exps.get(i))) {
Double exp = Math.exp(exps.get(i));
expSum += exp;
exps.set(i, exp);
}
}
for (int i = 0; i < exps.size(); i++) {
if (isInvalid(exps.get(i))) {
exps.set(i, 0.0);
} else {
if (isValid(exps.get(i))) {
exps.set(i, exps.get(i)/expSum);
} else {
exps.set(i, 0.0);
}
}
return exps;
@ -49,8 +51,8 @@ public final class Statistics {
return 1/(1 + Math.exp(-value));
}
public static boolean isInvalid(Double v) {
return v == null || Double.isInfinite(v) || Double.isNaN(v);
private static boolean isValid(Double v) {
return v != null && Numbers.isValidDouble(v);
}
}
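A numeric sketch of the two utilities, assuming sigmoid keeps the double-valued signature implied by the body above: softMax zeroes invalid entries (null, NaN, infinite) and normalizes the valid ones to sum to 1.

import java.util.Arrays;
import java.util.List;

class StatisticsSketch {
    static void example() {
        List<Double> probabilities = Statistics.softMax(Arrays.asList(1.0, 2.0, null));
        // exp(1-2) ~= 0.368 and exp(0) = 1.0 normalize to ~[0.269, 0.731, 0.0]
        double p = Statistics.sigmoid(0.0); // 0.5
    }
}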
@ -83,7 +83,13 @@ public final class Messages {
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
"Getting model definition is not supported when getting more than one model";
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
public static final String JOB_AUDIT_CREATED = "Job created";
@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.notifications;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import java.util.Date;
public class InferenceAuditMessage extends AbstractAuditMessage {
//TODO this should be MODEL_ID...
private static final ParseField JOB_ID = Job.ID;
public static final ConstructingObjectParser<InferenceAuditMessage, Void> PARSER =
createParser("ml_inference_audit_message", InferenceAuditMessage::new, JOB_ID);
public InferenceAuditMessage(String resourceId, String message, Level level, Date timestamp, String nodeName) {
super(resourceId, message, level, timestamp, nodeName);
}
@Override
public final String getJobType() {
return "inference";
}
@Override
protected String getResourceField() {
return JOB_ID.getPreferredName();
}
}

View File

@ -43,6 +43,10 @@ public class ExceptionsHelper {
return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id);
}
public static ResourceNotFoundException missingTrainedModel(String modelId) {
return new ResourceNotFoundException("No known trained model with model_id [{}]", modelId);
}
public static ElasticsearchException serverError(String msg) {
return new ElasticsearchException(msg);
}
@ -51,6 +55,10 @@ public class ExceptionsHelper {
return new ElasticsearchException(msg, cause);
}
public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) {
return new ElasticsearchException(msg, cause, args);
}
public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... args) {
return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args);
}
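
These additions keep REST status codes consistent with the rest of the helper: ResourceNotFoundException renders as a 404, the conflict variant as a 409, and the plain ElasticsearchException overloads as generic 500-level server errors, so missingTrainedModel gives callers a proper not-found response rather than an opaque failure.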

View File

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction.Request;
public class DeleteTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
@Override
protected Request createTestInstance() {
return new Request(randomAlphaOfLengthBetween(1, 20));
}
@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
}
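
Like the other AbstractWireSerializingTestCase subclasses added in this change, this test works by round-tripping: createTestInstance produces a random Request, the harness serializes it to the transport wire format, reads it back through instanceReader, and asserts the copy equals the original.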

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCase<Request> {
@Override
protected Request createTestInstance() {
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
return request;
}
@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Response;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class GetTrainedModelsStatsActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {
int listSize = randomInt(10);
List<Response.TrainedModelStats> trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10))
.limit(listSize).map(id ->
new Response.TrainedModelStats(id,
randomBoolean() ? randomIngestStats() : null,
randomIntBetween(0, 10))
)
.collect(Collectors.toList());
return new Response(new QueryPage<>(trainedModelStats, randomLongBetween(listSize, 1000), Response.RESULTS_FIELD));
}
private IngestStats randomIngestStats() {
List<String> pipelineIds = Stream.generate(()-> randomAlphaOfLength(10))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList());
return new IngestStats(
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats())));
}
private IngestStats.Stats randomStats(){
return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong());
}
private List<IngestStats.ProcessorStat> randomProcessorStats() {
return Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomIntBetween(0, 10))
.map(name -> new IngestStats.ProcessorStat(name, "inference", randomStats()))
.collect(Collectors.toList());
}
@Override
protected Writeable.Reader<Response> instanceReader() {
return Response::new;
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionRequestTests extends AbstractWireSerializingTestCase<Request> {
@Override
protected Request createTestInstance() {
return randomBoolean() ?
new Request(
randomAlphaOfLength(10),
Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()),
randomInferenceConfig()) :
new Request(
randomAlphaOfLength(10),
randomMap(),
randomInferenceConfig());
}
private static InferenceConfig randomInferenceConfig() {
return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig());
}
private static Map<String, Object> randomMap() {
return Stream.generate(()-> randomAlphaOfLength(10))
.limit(randomInt(10))
.collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10)));
}
@Override
protected Writeable.Reader<Request> instanceReader() {
return Request::new;
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class InferModelActionResponseTests extends AbstractWireSerializingTestCase<Response> {
@Override
protected Response createTestInstance() {
String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME);
return new Response(
Stream.generate(() -> randomInferenceResult(resultType))
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()));
}
private static InferenceResults randomInferenceResult(String resultType) {
if (resultType.equals(ClassificationInferenceResults.NAME)) {
return ClassificationInferenceResultsTests.createRandomResults();
} else if (resultType.equals(RegressionInferenceResults.NAME)) {
return RegressionInferenceResultsTests.createRandomResults();
} else {
fail("unexpected result type [" + resultType + "]");
return null;
}
}
@Override
protected Writeable.Reader<Response> instanceReader() {
return Response::new;
}
@Override
protected NamedWriteableRegistry getNamedWriteableRegistry() {
List<NamedWriteableRegistry.Entry> entries = new ArrayList<>();
entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables());
return new NamedWriteableRegistry(entries);
}
}

View File

@ -33,6 +33,7 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
@ -73,7 +74,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
null, // is not parsed so should not be provided
tags,
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput());
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
}
@Override
@ -96,6 +99,16 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
return new NamedWriteableRegistry(entries);
}
@Override
protected ToXContent.Params getToXContentParams() {
return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
}
@Override
protected boolean assertToXContentEquivalence() {
return false;
}
public void testToXContentWithParams() throws IOException {
TrainedModelConfig config = new TrainedModelConfig(
randomAlphaOfLength(10),
@ -106,7 +119,9 @@ public class TrainedModelConfigTests extends AbstractSerializingTestCase<Trained
TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(),
Collections.emptyList(),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
TrainedModelInputTests.createRandomInput());
TrainedModelInputTests.createRandomInput(),
randomNonNegativeLong(),
randomNonNegativeLong());
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
assertThat(reference.utf8ToString(), containsString("definition"));

View File

@ -5,12 +5,16 @@
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.search.SearchModule;
@ -19,9 +23,9 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
@ -31,26 +35,22 @@ import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
private boolean lenient;
@Before
public void chooseStrictOrLenient() {
lenient = randomBoolean();
}
@Override
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
return TrainedModelDefinition.fromXContent(parser, lenient).build();
return TrainedModelDefinition.fromXContent(parser, true).build();
}
@Override
protected boolean supportsUnknownFields() {
return lenient;
return true;
}
@Override
@ -58,6 +58,16 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
return field -> !field.isEmpty();
}
@Override
protected ToXContent.Params getToXContentParams() {
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
}
@Override
protected boolean assertToXContentEquivalence() {
return false;
}
public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) {
int numberOfProcessors = randomIntBetween(1, 10);
return new TrainedModelDefinition.Builder()
@ -69,7 +79,7 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom()));
.setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
}
private static final String ENSEMBLE_MODEL = "" +
@ -273,9 +283,27 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
}
public void testStrictParser() throws IOException {
TrainedModelDefinition.Builder builder = createRandomBuilder("asdf");
BytesReference reference = XContentHelper.toXContent(builder.build(),
XContentType.JSON,
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")),
false);
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
reference,
XContentType.JSON);
XContentParseException exception = expectThrows(XContentParseException.class,
() -> TrainedModelDefinition.fromXContent(parser, false));
assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]"));
}
@Override
protected TrainedModelDefinition createTestInstance() {
return createRandomBuilder(null).build();
return createRandomBuilder(randomAlphaOfLength(10)).build();
}
@Override
@ -298,4 +326,9 @@ public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<Tra
return new NamedWriteableRegistry(entries);
}
public void testRamUsageEstimation() {
TrainedModelDefinition test = createTestInstance();
assertThat(test.ramBytesUsed(), greaterThan(0L));
}
}
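
The strict-parser test hinges on the FOR_INTERNAL_STORAGE param: serializing with it emits an extra doc_type field for index storage, which a lenient parse ignores but a strict parse rejects with the "unknown field [doc_type]" error asserted above. That asymmetry is also why assertToXContentEquivalence is disabled for this test class.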

View File

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase<ClassificationInferenceResults> {
public static ClassificationInferenceResults createRandomResults() {
return new ClassificationInferenceResults(randomDouble(),
randomBoolean() ? null : randomAlphaOfLength(10),
randomBoolean() ? null :
Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry)
.limit(randomIntBetween(0, 10))
.collect(Collectors.toList()));
}
private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() {
return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble());
}
public void testWriteResultsWithClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field", String.class), equalTo("foo"));
}
public void testWriteResultsWithoutClassificationLabel() {
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList());
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0"));
}
@SuppressWarnings("unchecked")
public void testWriteResultsWithTopClasses() {
List<ClassificationInferenceResults.TopClassEntry> entries = Arrays.asList(
new ClassificationInferenceResults.TopClassEntry("foo", 0.7),
new ClassificationInferenceResults.TopClassEntry("bar", 0.2),
new ClassificationInferenceResults.TopClassEntry("baz", 0.1));
ClassificationInferenceResults result = new ClassificationInferenceResults(1.0,
"foo",
entries);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
List<?> list = document.getFieldValue("result_field", List.class);
assertThat(list.size(), equalTo(3));
for(int i = 0; i < 3; i++) {
Map<String, Object> map = (Map<String, Object>)list.get(i);
assertThat(map, equalTo(entries.get(i).asValueMap()));
}
}
@Override
protected ClassificationInferenceResults createTestInstance() {
return createRandomResults();
}
@Override
protected Writeable.Reader<ClassificationInferenceResults> instanceReader() {
return ClassificationInferenceResults::new;
}
}

View File

@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {
public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble());
}
@Override
protected RawInferenceResults createTestInstance() {
return createRandomResults();
}
@Override
protected Writeable.Reader<RawInferenceResults> instanceReader() {
return RawInferenceResults::new;
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import java.util.HashMap;
import static org.hamcrest.Matchers.equalTo;
public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase<RegressionInferenceResults> {
public static RegressionInferenceResults createRandomResults() {
return new RegressionInferenceResults(randomDouble());
}
public void testWriteResults() {
RegressionInferenceResults result = new RegressionInferenceResults(0.3);
IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>());
result.writeResult(document, "result_field");
assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3));
}
@Override
protected RegressionInferenceResults createTestInstance() {
return createRandomResults();
}
@Override
protected Writeable.Reader<RegressionInferenceResults> instanceReader() {
return RegressionInferenceResults::new;
}
}

View File

@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
public class ClassificationConfigTests extends AbstractWireSerializingTestCase<ClassificationConfig> {
public static ClassificationConfig randomClassificationConfig() {
return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10));
}
public void testFromMap() {
ClassificationConfig expected = new ClassificationConfig(0);
assertThat(ClassificationConfig.fromMap(Collections.emptyMap()), equalTo(expected));
expected = new ClassificationConfig(3);
assertThat(ClassificationConfig.fromMap(Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 3)),
equalTo(expected));
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> ClassificationConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}
@Override
protected ClassificationConfig createTestInstance() {
return randomClassificationConfig();
}
@Override
protected Writeable.Reader<ClassificationConfig> instanceReader() {
return ClassificationConfig::new;
}
}

View File

@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.test.ESTestCase;
import java.util.HashMap;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
public class InferenceHelpersTests extends ESTestCase {
public void testToDoubleFromNumbers() {
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5)));
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5L)));
assertThat(5.0, equalTo(InferenceHelpers.toDouble(5)));
assertThat(0.5, equalTo(InferenceHelpers.toDouble(0.5f)));
}
public void testToDoubleFromString() {
assertThat(0.5, equalTo(InferenceHelpers.toDouble("0.5")));
assertThat(-0.5, equalTo(InferenceHelpers.toDouble("-0.5")));
assertThat(5.0, equalTo(InferenceHelpers.toDouble("5")));
assertThat(-5.0, equalTo(InferenceHelpers.toDouble("-5")));
// if assertions are disabled, we should get a null value
// otherwise, we expect an assertion error telling us that the string is improperly formatted
try {
assertThat(InferenceHelpers.toDouble(""), is(nullValue()));
} catch (AssertionError ae) {
assertThat(ae.getMessage(), equalTo("value is not properly formatted double []"));
}
try {
assertThat(InferenceHelpers.toDouble("notADouble"), is(nullValue()));
} catch (AssertionError ae) {
assertThat(ae.getMessage(), equalTo("value is not properly formatted double [notADouble]"));
}
}
public void testToDoubleFromNull() {
assertThat(InferenceHelpers.toDouble(null), is(nullValue()));
}
public void testDoubleFromUnknownObj() {
assertThat(InferenceHelpers.toDouble(new HashMap<>()), is(nullValue()));
}
}
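
Taken together, these cases pin down toDouble's lenient contract: Numbers and numeric strings coerce to double, while nulls, empty or malformed strings, and unknown object types yield null rather than throwing. The try/catch around the string cases exists because the production code asserts on malformed strings, so with JVM assertions enabled (as in tests) the failure surfaces as an AssertionError instead of a null.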

View File

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
public class RegressionConfigTests extends AbstractWireSerializingTestCase<RegressionConfig> {
public static RegressionConfig randomRegressionConfig() {
return new RegressionConfig();
}
public void testFromMap() {
RegressionConfig expected = new RegressionConfig();
assertThat(RegressionConfig.fromMap(Collections.emptyMap()), equalTo(expected));
}
public void testFromMapWithUnknownField() {
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
() -> RegressionConfig.fromMap(Collections.singletonMap("some_key", 1)));
assertThat(ex.getMessage(), equalTo("Unrecognized fields [some_key]."));
}
@Override
protected RegressionConfig createTestInstance() {
return randomRegressionConfig();
}
@Override
protected Writeable.Reader<RegressionConfig> instanceReader() {
return RegressionConfig::new;
}
}

View File

@ -15,13 +15,16 @@ import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
@ -68,9 +71,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
.limit(numberOfModels)
.collect(Collectors.toList());
List<Double> weights = randomBoolean() ?
double[] weights = randomBoolean() ?
null :
Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).mapToDouble(Double::valueOf).toArray();
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights),
new WeightedSum(weights),
new LogisticRegression(weights));
@ -114,9 +117,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() {
List<String> featureNames = Arrays.asList("foo", "bar");
int numberOfModels = 5;
List<Double> weights = new ArrayList<>(numberOfModels + 2);
double[] weights = new double[numberOfModels + 2];
for (int i = 0; i < numberOfModels + 2; i++) {
weights.add(randomDouble());
weights[i] = randomDouble();
}
OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights));
@ -158,6 +161,27 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
});
}
public void testEnsembleWithAggregatorOutputNotSupportingTargetType() {
List<String> featureNames = Arrays.asList("foo", "bar");
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
Ensemble.builder()
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(
Tree.builder()
.setNodes(TreeNode.builder(0)
.setLeftChild(1)
.setSplitFeature(1)
.setThreshold(randomDouble()))
.setFeatureNames(featureNames)
.build()))
.setClassificationLabels(Arrays.asList("label1", "label2"))
.setTargetType(TargetType.CLASSIFICATION)
.setOutputAggregator(new WeightedSum())
.build()
.validate();
});
}
public void testEnsembleWithTargetTypeAndLabelsMismatch() {
List<String> featureNames = Arrays.asList("foo", "bar");
String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa";
@ -189,6 +213,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setFeatureNames(featureNames)
.build()))
.setTargetType(TargetType.CLASSIFICATION)
.setOutputAggregator(new WeightedMode())
.build()
.validate();
});
@ -236,32 +261,35 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
List<Double> expected = Arrays.asList(0.231475216, 0.768524783);
List<Double> expected = Arrays.asList(0.768524783, 0.231475216);
double eps = 0.000001;
List<Double> probabilities = ensemble.classificationProbability(featureMap);
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
}
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.3100255188, 0.689974481);
probabilities = ensemble.classificationProbability(featureMap);
expected = Arrays.asList(0.689974481, 0.3100255188);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
}
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
expected = Arrays.asList(0.231475216, 0.768524783);
probabilities = ensemble.classificationProbability(featureMap);
expected = Arrays.asList(0.768524783, 0.231475216);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
}
// This should handle missing values and take the default_left path
@ -270,9 +298,10 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
put("bar", null);
}};
expected = Arrays.asList(0.6899744811, 0.3100255188);
probabilities = ensemble.classificationProbability(featureMap);
probabilities =
((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expected.size(); i++) {
assertThat(probabilities.get(i), closeTo(expected.get(i), eps));
assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps));
}
}
@ -292,7 +321,9 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
@ -302,6 +333,7 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
@ -312,31 +344,36 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0)))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
featureVector = Arrays.asList(0.0, 1.0);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(0.0, ensemble.infer(featureMap), 0.00001);
assertThat(0.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001));
}
public void testRegressionInference() {
@ -370,16 +407,18 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2))
.setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5)))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5}))
.build();
List<Double> featureVector = Arrays.asList(0.4, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.9, ensemble.infer(featureMap), 0.00001);
assertThat(0.9,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(0.5, ensemble.infer(featureMap), 0.00001);
assertThat(0.5,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
// Test with NO aggregator supplied, verifies default behavior of non-weighted sum
ensemble = Ensemble.builder()
@ -390,17 +429,32 @@ public class EnsembleTests extends AbstractSerializingTestCase<Ensemble> {
featureVector = Arrays.asList(0.4, 0.0);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
featureVector = Arrays.asList(2.0, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(1.0, ensemble.infer(featureMap), 0.00001);
assertThat(1.0,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(1.8, ensemble.infer(featureMap), 0.00001);
assertThat(1.8,
closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001));
}
public void testOperationsEstimations() {
Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2);
Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
Tree tree3 = TreeTests.buildRandomTree(Arrays.asList("foo", "baz"), 3);
Ensemble ensemble = Ensemble.builder().setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(Arrays.asList("foo", "bar", "baz"))
.setOutputAggregator(new LogisticRegression(new double[]{0.1, 0.4, 1.0}))
.build();
assertThat(ensemble.estimatedNumOperations(), equalTo(9L));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {

View File

@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticRegression> {
@Override
LogisticRegression createTestInstance(int numberOfWeights) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
return new LogisticRegression(weights);
}
@ -41,13 +42,13 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
}
public void testAggregate() {
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
LogisticRegression logisticRegression = new LogisticRegression(ones);
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
List<Double> variedWeights = Arrays.asList(.01, -1.0, .1, 0.0, 0.0);
double[] variedWeights = new double[]{.01, -1.0, .1, 0.0, 0.0};
logisticRegression = new LogisticRegression(variedWeights);
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(0.0));
@ -56,4 +57,9 @@ public class LogisticRegressionTests extends WeightedAggregatorTests<LogisticReg
assertThat(logisticRegression.aggregate(logisticRegression.processValues(values)), equalTo(1.0));
}
public void testCompatibleWith() {
LogisticRegression logisticRegression = createTestInstance();
assertThat(logisticRegression.compatibleWith(TargetType.CLASSIFICATION), is(true));
assertThat(logisticRegression.compatibleWith(TargetType.REGRESSION), is(true));
}
}
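
The expected classes appear to read directly off the inputs, assuming processValues applies the logistic function to the weighted sum and aggregate picks the most probable class: with unit weights the weighted sum of [1, 2, 2, 3, 5] is 13, whose sigmoid is effectively 1, giving class 1.0; with weights [.01, -1, .1, 0, 0] the sum is 0.01 - 2 + 0.2 = -1.79, whose sigmoid is about 0.143, below 0.5, giving class 0.0.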

View File

@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
@Override
WeightedMode createTestInstance(int numberOfWeights) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
return new WeightedMode(weights);
}
@ -41,13 +42,13 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
}
public void testAggregate() {
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
WeightedMode weightedMode = new WeightedMode(ones);
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
weightedMode = new WeightedMode(variedWeights);
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(5.0));
@ -55,4 +56,10 @@ public class WeightedModeTests extends WeightedAggregatorTests<WeightedMode> {
weightedMode = new WeightedMode();
assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0));
}
public void testCompatibleWith() {
WeightedMode weightedMode = createTestInstance();
assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true));
assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true));
}
}
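
Reading the assertions, each distinct value appears to accumulate the weights of the votes cast for it: with unit weights, 2.0 wins (two votes against one each for 1.0, 3.0, and 5.0); with weights [1, -1, .5, 1, 5], the single 5.0 vote carries weight 5.0 and wins; and the no-argument constructor falls back to an unweighted mode, again 2.0.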

View File

@ -8,20 +8,21 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
@Override
WeightedSum createTestInstance(int numberOfWeights) {
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList());
double[] weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).mapToDouble(Double::valueOf).toArray();
return new WeightedSum(weights);
}
@ -41,13 +42,13 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
}
public void testAggregate() {
List<Double> ones = Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0);
double[] ones = new double[]{1.0, 1.0, 1.0, 1.0, 1.0};
List<Double> values = Arrays.asList(1.0, 2.0, 2.0, 3.0, 5.0);
WeightedSum weightedSum = new WeightedSum(ones);
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
List<Double> variedWeights = Arrays.asList(1.0, -1.0, .5, 1.0, 5.0);
double[] variedWeights = new double[]{1.0, -1.0, .5, 1.0, 5.0};
weightedSum = new WeightedSum(variedWeights);
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(28.0));
@ -55,4 +56,10 @@ public class WeightedSumTests extends WeightedAggregatorTests<WeightedSum> {
weightedSum = new WeightedSum();
assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0));
}
public void testCompatibleWith() {
WeightedSum weightedSum = createTestInstance();
assertThat(weightedSum.compatibleWith(TargetType.CLASSIFICATION), is(false));
assertThat(weightedSum.compatibleWith(TargetType.REGRESSION), is(true));
}
}
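
The expected sums are direct arithmetic: unit weights over [1, 2, 2, 3, 5] give 13.0; weights [1, -1, .5, 1, 5] give 1 - 2 + 1 + 3 + 25 = 28.0; and with no weights supplied the aggregator defaults to the plain sum, 13.0 again. The new testCompatibleWith also documents that a weighted sum only makes sense for regression targets, unlike the mode and logistic-regression aggregators above.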

View File

@ -10,6 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.junit.Before;
@ -104,6 +108,19 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
return Tree::new;
}
public void testInferWithStump() {
Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION);
builder.setRoot(TreeNode.builder(0).setLeafValue(42.0));
builder.setFeatureNames(Collections.emptyList());
Tree tree = builder.build();
List<String> featureNames = Arrays.asList("foo", "bar");
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector); // does not really matter as this is a stump
assertThat(42.0,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
}
public void testInfer() {
// Build a tree with 2 nodes and 3 leaves using 2 features
// The leaves have unique values 0.1, 0.2, 0.3
@ -120,26 +137,36 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001));
assertThat(0.3,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001));
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
// This should still work if the internal values are strings
List<String> featureVectorStrings = Arrays.asList("0.3", "0.9");
featureMap = zipObjMap(featureNames, featureVectorStrings);
assertThat(0.2,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001));
assertThat(0.1,
closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001));
}
public void testTreeClassificationProbability() {
@ -153,31 +180,43 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
builder.addLeaf(leftChildNode.getRightChild(), 0.0);
List<String> featureNames = Arrays.asList("foo", "bar");
Tree tree = builder.setFeatureNames(featureNames).build();
Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build();
double eps = 0.000001;
// This feature vector should hit the right child of the root node
List<Double> featureVector = Arrays.asList(0.6, 0.0);
List<Double> expectedProbs = Arrays.asList(1.0, 0.0);
List<String> expectedFields = Arrays.asList("dog", "cat");
Map<String, Object> featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
List<ClassificationInferenceResults.TopClassEntry> probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should hit the left child of the left child of the root node
// i.e. it takes the path left, left
featureVector = Arrays.asList(0.3, 0.7);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap));
// This should hit the right child of the left child of the root node
// i.e. it takes the path left, right
featureVector = Arrays.asList(0.3, 0.9);
featureMap = zipObjMap(featureNames, featureVector);
assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap));
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
// This should handle missing values and take the default_left path
featureMap = new HashMap<String, Object>(2) {{
put("foo", 0.3);
put("bar", null);
}};
assertEquals(1.0, tree.infer(featureMap), 0.00001);
probabilities =
((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses();
for(int i = 0; i < expectedProbs.size(); i++) {
assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps));
assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i)));
}
}
public void testTreeWithNullRoot() {
@ -261,7 +300,12 @@ public class TreeTests extends AbstractSerializingTestCase<Tree> {
assertThat(ex.getMessage(), equalTo(msg));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<Double> values) {
public void testOperationsEstimations() {
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
assertThat(tree.estimatedNumOperations(), equalTo(7L));
}
private static Map<String, Object> zipObjMap(List<String> keys, List<? extends Object> values) {
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get));
}
}

View File

@ -6,19 +6,16 @@
package org.elasticsearch.xpack.core.ml.notifications;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.Level;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import java.util.Date;
import static org.hamcrest.Matchers.equalTo;
public class AnomalyDetectionAuditMessageTests extends AuditMessageTests<AnomalyDetectionAuditMessage> {
public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase<AnomalyDetectionAuditMessage> {
public void testGetJobType() {
AnomalyDetectionAuditMessage message = createTestInstance();
assertThat(message.getJobType(), equalTo(Job.ANOMALY_DETECTOR_JOB_TYPE));
@Override
public String getJobType() {
return Job.ANOMALY_DETECTOR_JOB_TYPE;
}
@Override
@ -26,11 +23,6 @@ public class AnomalyDetectionAuditMessageTests extends AbstractXContentTestCase<
return AnomalyDetectionAuditMessage.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected AnomalyDetectionAuditMessage createTestInstance() {
return new AnomalyDetectionAuditMessage(

View File

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.notifications;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import static org.hamcrest.Matchers.equalTo;
public abstract class AuditMessageTests<T extends AbstractAuditMessage> extends AbstractXContentTestCase<T> {
public abstract String getJobType();
public void testGetJobType() {
AbstractAuditMessage message = createTestInstance();
assertThat(message.getJobType(), equalTo(getJobType()));
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
}

View File

@ -6,30 +6,22 @@
package org.elasticsearch.xpack.core.ml.notifications;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractXContentTestCase;
import org.elasticsearch.xpack.core.common.notifications.Level;
import java.util.Date;
import static org.hamcrest.Matchers.equalTo;
public class DataFrameAnalyticsAuditMessageTests extends AuditMessageTests<DataFrameAnalyticsAuditMessage> {
public class DataFrameAnalyticsAuditMessageTests extends AbstractXContentTestCase<DataFrameAnalyticsAuditMessage> {
public void testGetJobType() {
DataFrameAnalyticsAuditMessage message = createTestInstance();
assertThat(message.getJobType(), equalTo("data_frame_analytics"));
@Override
public String getJobType() {
return "data_frame_analytics";
}
@Override
protected DataFrameAnalyticsAuditMessage doParseInstance(XContentParser parser) {
return DataFrameAnalyticsAuditMessage.PARSER.apply(parser, null);
}
@Override
protected boolean supportsUnknownFields() {
return true;
}
@Override
protected DataFrameAnalyticsAuditMessage createTestInstance() {
return new DataFrameAnalyticsAuditMessage(

View File

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.notifications;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.common.notifications.Level;
import java.util.Date;
public class InferenceAuditMessageTests extends AuditMessageTests<InferenceAuditMessage> {
@Override
public String getJobType() {
return "inference";
}
@Override
protected InferenceAuditMessage doParseInstance(XContentParser parser) {
return InferenceAuditMessage.PARSER.apply(parser, null);
}
@Override
protected InferenceAuditMessage createTestInstance() {
return new InferenceAuditMessage(
randomBoolean() ? null : randomAlphaOfLength(10),
randomAlphaOfLengthBetween(1, 20),
randomFrom(Level.values()),
new Date(),
randomBoolean() ? null : randomAlphaOfLengthBetween(1, 20)
);
}
}

View File: build.gradle

@ -126,6 +126,13 @@ integTest.runner {
'ml/filter_crud/Test get all filter given index exists but no mapping for filter_id',
'ml/get_datafeed_stats/Test get datafeed stats given missing datafeed_id',
'ml/get_datafeeds/Test get datafeed given missing datafeed_id',
'ml/inference_crud/Test delete given used trained model',
'ml/inference_crud/Test delete given unused trained model',
'ml/inference_crud/Test delete with missing model',
'ml/inference_crud/Test get given missing trained model',
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
'ml/jobs_crud/Test cannot create job with existing categorizer state document',
'ml/jobs_crud/Test cannot create job with existing quantiles document',
'ml/jobs_crud/Test cannot create job with existing result document',

View File: org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

@ -0,0 +1,544 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
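// Index the trained model config and definition documents directly into the
// internal inference index so the pipelines under test can load the models.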
@Before
public void createBothModels() {
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId("test_classification")
.setSource(CLASSIFICATION_CONFIG, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinition.docId("test_classification"))
.setSource(CLASSIFICATION_DEFINITION, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId("test_regression")
.setSource(REGRESSION_CONFIG, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinition.docId("test_regression"))
.setSource(REGRESSION_DEFINITION, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
}
public void testPipelineCreationAndDeletion() throws Exception {
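// Repeatedly create and delete pipelines referencing the models; once no
// pipeline references a model it can be evicted, so each cycle exercises
// the lazy model-loading path again.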
for (int i = 0; i < 10; i++) {
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_classification_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_regression_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
is(true));
}
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
for (int i = 0; i < 10; i++) {
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(generateSourceDoc())
.setPipeline("simple_classification_pipeline")
.get();
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(generateSourceDoc())
.setPipeline("simple_regression_pipeline")
.get();
}
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
is(true));
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
is(true));
client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
.source(new SearchSourceBuilder()
.size(0)
.trackTotalHits(true)
.query(QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("regression_value"))))).get().getHits().getTotalHits().value,
equalTo(20L));
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
.source(new SearchSourceBuilder()
.size(0)
.trackTotalHits(true)
.query(QueryBuilders.boolQuery()
.filter(QueryBuilders.existsQuery("result_class"))))).get().getHits().getTotalHits().value,
equalTo(20L));
}
public void testSimulate() {
String source = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class\",\n" +
" \"inference_config\": {\"classification\":{}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class_prob\",\n" +
" \"inference_config\": {\"classification\": {\"num_top_classes\":2}},\n" +
" \"model_id\": \"test_classification\",\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"test_regression\",\n" +
" \"inference_config\": {\"regression\":{}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }\n" +
" ]\n" +
" },\n" +
" \"docs\": [\n" +
" {\"_source\": {\n" +
" \"col1\": \"female\",\n" +
" \"col2\": \"M\",\n" +
" \"col3\": \"none\",\n" +
" \"col4\": 10\n" +
" }}]\n" +
"}";
SimulatePipelineResponse response = client().admin().cluster()
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get();
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
assertThat(baseResult.getIngestDocument().getFieldValue("regression_value", Double.class), equalTo(1.0));
assertThat(baseResult.getIngestDocument().getFieldValue("result_class", String.class), equalTo("second"));
assertThat(baseResult.getIngestDocument().getFieldValue("result_class_prob", List.class).size(), equalTo(2));
String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class\",\n" +
" \"model_id\": \"test_classification_missing\",\n" +
" \"inference_config\": {\"classification\":{}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }\n" +
" ]\n" +
" },\n" +
" \"docs\": [\n" +
" {\"_source\": {\n" +
" \"col1\": \"female\",\n" +
" \"col2\": \"M\",\n" +
" \"col3\": \"none\",\n" +
" \"col4\": 10\n" +
" }}]\n" +
"}";
response = client().admin().cluster()
.prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get();
assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
containsString("Could not find trained model [test_classification_missing]"));
}
private Map<String, Object> generateSourceDoc() {
return new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}};
}
private static final String REGRESSION_DEFINITION = "{" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_sum\": {\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"regression\",\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"model_id\": \"test_regression\"\n" +
"}";
private static final String REGRESSION_CONFIG = "{" +
" \"model_id\": \"test_regression\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0" +
"}";
private static final String CLASSIFICATION_DEFINITION = "{" +
" \"preprocessors\": [\n" +
" {\n" +
" \"one_hot_encoding\": {\n" +
" \"field\": \"col1\",\n" +
" \"hot_map\": {\n" +
" \"male\": \"col1_male\",\n" +
" \"female\": \"col1_female\"\n" +
" }\n" +
" }\n" +
" },\n" +
" {\n" +
" \"target_mean_encoding\": {\n" +
" \"field\": \"col2\",\n" +
" \"feature_name\": \"col2_encoded\",\n" +
" \"target_map\": {\n" +
" \"S\": 5.0,\n" +
" \"M\": 10.0,\n" +
" \"L\": 20\n" +
" },\n" +
" \"default_value\": 5.0\n" +
" }\n" +
" },\n" +
" {\n" +
" \"frequency_encoding\": {\n" +
" \"field\": \"col3\",\n" +
" \"feature_name\": \"col3_encoded\",\n" +
" \"frequency_map\": {\n" +
" \"none\": 0.75,\n" +
" \"true\": 0.10,\n" +
" \"false\": 0.15\n" +
" }\n" +
" }\n" +
" }\n" +
" ],\n" +
" \"trained_model\": {\n" +
" \"ensemble\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"aggregate_output\": {\n" +
" \"weighted_mode\": {\n" +
" \"weights\": [\n" +
" 0.5,\n" +
" 0.5\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"target_type\": \"classification\",\n" +
" \"classification_labels\": [\"first\", \"second\"],\n" +
" \"trained_models\": [\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" },\n" +
" {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"model_id\": \"test_classification\"\n" +
"}";
private static final String CLASSIFICATION_CONFIG = "" +
"{\n" +
" \"model_id\": \"test_classification\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
"}";
private static final String CLASSIFICATION_PIPELINE = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"result_class\",\n" +
" \"model_id\": \"test_classification\",\n" +
" \"inference_config\": {\"classification\": {}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
private static final String REGRESSION_PIPELINE = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"test_regression\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
}

View File: org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

@ -0,0 +1,221 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
import org.junit.After;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
public class TrainedModelIT extends ESRestTestCase {
private static final String BASIC_AUTH_VALUE = basicAuthHeaderValue("x_pack_rest_user",
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
@Override
protected Settings restClientSettings() {
return Settings.builder().put(super.restClientSettings()).put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE).build();
}
@Override
protected boolean preserveTemplatesUponCompletion() {
return true;
}
public void testGetTrainedModels() throws IOException {
String modelId = "test_regression_model";
String modelId2 = "test_regression_model-2";
Request model1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
Request model2 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
model2.setJsonEntity(buildRegressionModel(modelId2));
assertThat(client().performRequest(model2).getStatusLine().getStatusCode(), equalTo(201));
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
Response getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/" + modelId));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
String response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"count\":1"));
getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression*"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
assertThat(response, not(containsString("\"definition\"")));
assertThat(response, containsString("\"count\":2"));
getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"heap_memory_estimation_bytes\""));
assertThat(response, containsString("\"heap_memory_estimation\""));
assertThat(response, containsString("\"definition\""));
assertThat(response, containsString("\"count\":1"));
ResponseException responseException = expectThrows(ResponseException.class, () ->
client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
assertThat(response, containsString("\"count\":2"));
getModel = client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=true"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"count\":0"));
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET",
MachineLearning.BASE_PATH + "inference/classification*?allow_no_match=false")));
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(404));
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=0&size=1"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"count\":2"));
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\"")));
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
response = EntityUtils.toString(getModel.getEntity());
assertThat(response, containsString("\"count\":2"));
assertThat(response, not(containsString("\"model_id\":\"test_regression_model\"")));
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
}
public void testDeleteTrainedModels() throws IOException {
String modelId = "test_delete_regression_model";
Request model1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
model1.setJsonEntity(buildRegressionModel(modelId));
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
Request modelDefinition1 = new Request("PUT",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
Response delModel = client().performRequest(new Request("DELETE",
MachineLearning.BASE_PATH + "inference/" + modelId));
String response = EntityUtils.toString(delModel.getEntity());
assertThat(response, containsString("\"acknowledged\":true"));
ResponseException responseException = expectThrows(ResponseException.class,
() -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId)));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
responseException = expectThrows(ResponseException.class,
() -> client().performRequest(
new Request("GET",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId))));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
responseException = expectThrows(ResponseException.class,
() -> client().performRequest(
new Request("GET",
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId)));
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
}
private static String buildRegressionModel(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
TrainedModelConfig.builder()
.setModelId(modelId)
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
.setCreatedBy("ml_test")
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
}
}
private static String buildRegressionModelDefinition(String modelId) throws IOException {
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
new TrainedModelDefinition.Builder()
.setPreProcessors(Collections.emptyList())
.setTrainedModel(LocalModelTests.buildRegression())
.setModelId(modelId)
.build()
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
}
}
@After
public void clearMlState() throws Exception {
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
ESRestTestCase.waitForPendingTasks(adminClient());
}
}

View File: org/elasticsearch/xpack/ml/MachineLearning.java

@ -43,12 +43,14 @@ import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.indices.analysis.AnalysisModule.AnalysisProvider;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.monitor.os.OsProbe;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.persistent.PersistentTasksExecutor;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.AnalysisPlugin;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.plugins.PersistentTaskPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
@ -72,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
@ -93,7 +96,10 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
@ -139,6 +145,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteFilterAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.TransportDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.action.TransportEstimateMemoryUsageAction;
import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction;
import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction;
@ -160,6 +167,9 @@ import org.elasticsearch.xpack.ml.action.TransportGetJobsStatsAction;
import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction;
import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction;
import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.action.TransportInferModelAction;
import org.elasticsearch.xpack.ml.action.TransportGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction;
import org.elasticsearch.xpack.ml.action.TransportKillProcessAction;
import org.elasticsearch.xpack.ml.action.TransportMlInfoAction;
@ -199,6 +209,8 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactor
import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.job.JobManager;
@ -221,6 +233,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory;
import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
@ -257,6 +270,9 @@ import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction;
import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction;
import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction;
import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction;
@ -297,7 +313,7 @@ import java.util.function.UnaryOperator;
import static java.util.Collections.emptyList;
import static org.elasticsearch.index.mapper.MapperService.SINGLE_MAPPING_NAME;
public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, PersistentTaskPlugin {
public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlugin, IngestPlugin, PersistentTaskPlugin {
public static final String NAME = "ml";
public static final String BASE_PATH = "/_ml/";
public static final String PRE_V7_BASE_PATH = "/_xpack/ml/";
@ -321,6 +337,22 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
};
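// Register the inference ingest processor. The factory listens for license
// state changes and ingest cluster state updates so it can enforce licensing
// and track how many inference processors are configured.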
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
if (this.enabled == false) {
return Collections.emptyMap();
}
InferenceProcessor.Factory inferenceFactory = new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
this.settings,
parameters.ingestService,
getLicenseState());
getLicenseState().addListener(inferenceFactory);
parameters.ingestService.addIngestClusterStateListener(inferenceFactory);
return Collections.singletonMap(InferenceProcessor.TYPE, inferenceFactory);
}
@Override
public Set<DiscoveryNodeRole> getRoles() {
return Collections.singleton(ML_ROLE);
@ -401,18 +433,21 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
public List<Setting<?>> getSettings() {
return Collections.unmodifiableList(
Arrays.asList(MachineLearningField.AUTODETECT_PROCESS,
PROCESS_CONNECT_TIMEOUT,
ML_ENABLED,
CONCURRENT_JOB_ALLOCATIONS,
MachineLearningField.MAX_MODEL_MEMORY_LIMIT,
MAX_LAZY_ML_NODES,
MAX_MACHINE_MEMORY_PERCENT,
AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING,
AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC,
MAX_OPEN_JOBS_PER_NODE,
MIN_DISK_SPACE_OFF_HEAP,
MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION));
Arrays.asList(MachineLearningField.AUTODETECT_PROCESS,
PROCESS_CONNECT_TIMEOUT,
ML_ENABLED,
CONCURRENT_JOB_ALLOCATIONS,
MachineLearningField.MAX_MODEL_MEMORY_LIMIT,
MAX_LAZY_ML_NODES,
MAX_MACHINE_MEMORY_PERCENT,
AutodetectBuilder.DONT_PERSIST_MODEL_STATE_SETTING,
AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC,
MAX_OPEN_JOBS_PER_NODE,
MIN_DISK_SPACE_OFF_HEAP,
MlConfigMigrationEligibilityCheck.ENABLE_CONFIG_MIGRATION,
InferenceProcessor.MAX_INFERENCE_PROCESSORS,
ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE,
ModelLoadingService.INFERENCE_MODEL_CACHE_TTL));
}
public Settings additionalSettings() {
@ -483,6 +518,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
AnomalyDetectionAuditor anomalyDetectionAuditor = new AnomalyDetectionAuditor(client, clusterService.getNodeName());
DataFrameAnalyticsAuditor dataFrameAnalyticsAuditor = new DataFrameAnalyticsAuditor(client, clusterService.getNodeName());
InferenceAuditor inferenceAuditor = new InferenceAuditor(client, clusterService.getNodeName());
this.dataFrameAnalyticsAuditor.set(dataFrameAnalyticsAuditor);
JobResultsProvider jobResultsProvider = new JobResultsProvider(client, settings);
JobResultsPersister jobResultsPersister = new JobResultsPersister(client);
@ -569,7 +605,12 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
this.datafeedManager.set(datafeedManager);
// Inference components
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry);
final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
inferenceAuditor,
threadPool,
clusterService,
settings);
// Data frame analytics components
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
@ -614,12 +655,14 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
datafeedManager,
anomalyDetectionAuditor,
dataFrameAnalyticsAuditor,
inferenceAuditor,
mlAssignmentNotifier,
memoryTracker,
analyticsProcessManager,
memoryEstimationProcessManager,
dataFrameAnalyticsConfigProvider,
nativeStorageProvider,
modelLoadingService,
trainedModelProvider
);
}
@ -717,7 +760,10 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new RestStartDataFrameAnalyticsAction(restController),
new RestStopDataFrameAnalyticsAction(restController),
new RestEvaluateDataFrameAction(restController),
new RestEstimateMemoryUsageAction(restController)
new RestEstimateMemoryUsageAction(restController),
new RestGetTrainedModelsAction(restController),
new RestDeleteTrainedModelAction(restController),
new RestGetTrainedModelsStatsAction(restController)
);
}
@ -784,7 +830,11 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu
new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class),
new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class),
new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class),
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class)
new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class),
new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class),
new ActionHandler<>(GetTrainedModelsAction.INSTANCE, TransportGetTrainedModelsAction.class),
new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class),
new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class)
);
}

View File: org/elasticsearch/xpack/ml/MachineLearningFeatureSet.java

@ -10,6 +10,12 @@ import org.apache.lucene.util.Constants;
import org.apache.lucene.util.Counter;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
@ -18,8 +24,10 @@ import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.env.Environment;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.Platforms;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackFeatureSet;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.XPackPlugin;
@ -31,11 +39,13 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.NativeControllerHolder;
@ -44,11 +54,14 @@ import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
@ -135,10 +148,10 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
@Override
public void usage(ActionListener<XPackFeatureSet.Usage> listener) {
ClusterState state = clusterService.state();
new Retriever(client, jobManagerHolder, available(), enabled(), mlNodeCount(state)).execute(listener);
new Retriever(client, jobManagerHolder, available(), enabled(), state).execute(listener);
}
private int mlNodeCount(final ClusterState clusterState) {
private static int mlNodeCount(boolean enabled, final ClusterState clusterState) {
if (enabled == false) {
return 0;
}
@ -161,9 +174,11 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
private Map<String, Object> jobsUsage;
private Map<String, Object> datafeedsUsage;
private Map<String, Object> analyticsUsage;
private Map<String, Object> inferenceUsage;
private final ClusterState state;
private int nodeCount;
public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, int nodeCount) {
public Retriever(Client client, JobManagerHolder jobManagerHolder, boolean available, boolean enabled, ClusterState state) {
this.client = Objects.requireNonNull(client);
this.jobManagerHolder = jobManagerHolder;
this.available = available;
@ -171,61 +186,15 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
this.jobsUsage = new LinkedHashMap<>();
this.datafeedsUsage = new LinkedHashMap<>();
this.analyticsUsage = new LinkedHashMap<>();
this.nodeCount = nodeCount;
this.inferenceUsage = new LinkedHashMap<>();
this.nodeCount = mlNodeCount(enabled, state);
this.state = state;
}
public void execute(ActionListener<Usage> listener) {
// empty holder means either ML disabled or transport client mode
if (jobManagerHolder.isEmpty()) {
listener.onResponse(
new MachineLearningFeatureSetUsage(available,
enabled,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
0));
return;
}
// Step 3. Extract usage from data frame analytics and return usage response
ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
response -> {
addDataFrameAnalyticsUsage(response, analyticsUsage);
listener.onResponse(new MachineLearningFeatureSetUsage(available,
enabled,
jobsUsage,
datafeedsUsage,
analyticsUsage,
nodeCount));
},
listener::onFailure
);
// Step 2. Extract usage from datafeeds stats and return usage response
ActionListener<GetDatafeedsStatsAction.Response> datafeedStatsListener =
ActionListener.wrap(response -> {
addDatafeedsUsage(response);
GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest =
new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL);
dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000));
client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener);
},
listener::onFailure);
// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL);
ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(
response -> {
jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> {
addJobsUsage(response, jobs.results());
GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(
GetDatafeedsStatsAction.ALL);
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener);
}, listener::onFailure));
}, listener::onFailure);
// Step 0. Kick off the chain of callbacks by requesting jobs stats
client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener);
private static void initializeStats(Map<String, Long> emptyStatsMap) {
emptyStatsMap.put("sum", 0L);
emptyStatsMap.put("min", 0L);
emptyStatsMap.put("max", 0L);
}
private void addJobsUsage(GetJobsStatsAction.Response response, List<Job> jobs) {
@ -322,7 +291,7 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
}
private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsStatsAction.Response response,
Map<String, Object> dataframeAnalyticsUsage) {
Map<String, Object> dataframeAnalyticsUsage) {
Map<DataFrameAnalyticsState, Counter> dataFrameAnalyticsStateCounterMap = new HashMap<>();
for(GetDataFrameAnalyticsStatsAction.Response.Stats stats : response.getResponse().results()) {
@ -334,5 +303,149 @@ public class MachineLearningFeatureSet implements XPackFeatureSet {
createCountUsageEntry(dataFrameAnalyticsStateCounterMap.get(state).get()));
}
}
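// Accumulate sum/min/max into a stats map previously seeded by initializeStats.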
private static void updateStats(Map<String, Long> statsMap, Long value) {
statsMap.compute("sum", (k, v) -> v + value);
statsMap.compute("min", (k, v) -> Math.min(v, value));
statsMap.compute("max", (k, v) -> Math.max(v, value));
}
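// Collect the ids of all ingest nodes from the cluster state.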
private static String[] ingestNodes(final ClusterState clusterState) {
String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()];
Iterator<String> nodeIterator = clusterState.nodes().getIngestNodes().keysIt();
int i = 0;
while (nodeIterator.hasNext()) {
ingestNodes[i++] = nodeIterator.next();
}
return ingestNodes;
}
public void execute(ActionListener<Usage> listener) {
// empty holder means either ML disabled or transport client mode
if (jobManagerHolder.isEmpty()) {
listener.onResponse(
new MachineLearningFeatureSetUsage(available,
enabled,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap(),
0));
return;
}
// Step 5. Extract trained model config count and then return results
ActionListener<SearchResponse> trainedModelConfigCountListener = ActionListener.wrap(
response -> {
addTrainedModelStats(response, inferenceUsage);
MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(available,
enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount);
listener.onResponse(usage);
},
listener::onFailure
);
// Step 4. Extract usage from ingest statistics and gather trained model config count
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
response -> {
addInferenceIngestUsage(response, inferenceUsage);
SearchRequestBuilder requestBuilder = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.setSize(0)
.setTrackTotalHits(true);
ClientHelper.executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ClientHelper.ML_ORIGIN,
requestBuilder.request(),
trainedModelConfigCountListener,
client::search);
},
listener::onFailure
);
// Step 3. Extract usage from data frame analytics stats and then request ingest node stats
ActionListener<GetDataFrameAnalyticsStatsAction.Response> dataframeAnalyticsListener = ActionListener.wrap(
response -> {
addDataFrameAnalyticsUsage(response, analyticsUsage);
String[] ingestNodes = ingestNodes(state);
NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true);
client.execute(NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
},
listener::onFailure
);
// Step 2. Extract usage from datafeeds stats and then request data frame analytics stats
ActionListener<GetDatafeedsStatsAction.Response> datafeedStatsListener =
ActionListener.wrap(response -> {
addDatafeedsUsage(response);
GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest =
new GetDataFrameAnalyticsStatsAction.Request(GetDatafeedsStatsAction.ALL);
dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10_000));
client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, dataframeAnalyticsStatsRequest, dataframeAnalyticsListener);
},
listener::onFailure);
// Step 1. Extract usage from jobs stats and then request stats for all datafeeds
GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(MetaData.ALL);
ActionListener<GetJobsStatsAction.Response> jobStatsListener = ActionListener.wrap(
response -> {
jobManagerHolder.getJobManager().expandJobs(MetaData.ALL, true, ActionListener.wrap(jobs -> {
addJobsUsage(response, jobs.results());
GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(
GetDatafeedsStatsAction.ALL);
client.execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest, datafeedStatsListener);
}, listener::onFailure));
}, listener::onFailure);
// Step 0. Kick off the chain of callbacks by requesting jobs stats
client.execute(GetJobsStatsAction.INSTANCE, jobStatsRequest, jobStatsListener);
}
// TODO separate out our models and user models, possibly regression vs classification
private void addTrainedModelStats(SearchResponse response, Map<String, Object> inferenceUsage) {
inferenceUsage.put("trained_models",
Collections.singletonMap(MachineLearningFeatureSetUsage.ALL,
createCountUsageEntry(response.getHits().getTotalHits().value)));
}
// TODO separate out our models and user models, possibly regression vs classification
private void addInferenceIngestUsage(NodesStatsResponse response, Map<String, Object> inferenceUsage) {
Set<String> pipelines = new HashSet<>();
Map<String, Long> docCountStats = new HashMap<>(3);
Map<String, Long> timeStats = new HashMap<>(3);
Map<String, Long> failureStats = new HashMap<>(3);
initializeStats(docCountStats);
initializeStats(timeStats);
initializeStats(failureStats);
response.getNodes()
.stream()
.map(NodeStats::getIngestStats)
.map(IngestStats::getProcessorStats)
.forEach(map ->
map.forEach((pipelineId, processors) -> {
boolean containsInference = false;
for (IngestStats.ProcessorStat stats : processors) {
if (stats.getName().equals(InferenceProcessor.TYPE)) {
containsInference = true;
long ingestCount = stats.getStats().getIngestCount();
long ingestTime = stats.getStats().getIngestTimeInMillis();
long failureCount = stats.getStats().getIngestFailedCount();
updateStats(docCountStats, ingestCount);
updateStats(timeStats, ingestTime);
updateStats(failureStats, failureCount);
}
}
if (containsInference) {
pipelines.add(pipelineId);
}
})
);
Map<String, Object> ingestUsage = new HashMap<>(6);
ingestUsage.put("pipelines", createCountUsageEntry(pipelines.size()));
ingestUsage.put("num_docs_processed", docCountStats);
ingestUsage.put("time_ms", timeStats);
ingestUsage.put("num_failures", failureStats);
inferenceUsage.put("ingest_processors", Collections.singletonMap(MachineLearningFeatureSetUsage.ALL, ingestUsage));
}
}
}

View File: org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java

@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
* The action is a master node action to ensure it reads an up-to-date cluster
* state in order to determine if there is a processor referencing the trained model
*/
public class TransportDeleteTrainedModelAction
extends TransportMasterNodeAction<DeleteTrainedModelAction.Request, AcknowledgedResponse> {
private static final Logger LOGGER = LogManager.getLogger(TransportDeleteTrainedModelAction.class);
private final TrainedModelProvider trainedModelProvider;
private final InferenceAuditor auditor;
private final IngestService ingestService;
@Inject
public TransportDeleteTrainedModelAction(TransportService transportService, ClusterService clusterService,
ThreadPool threadPool, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
TrainedModelProvider configProvider, InferenceAuditor auditor,
IngestService ingestService) {
super(DeleteTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters,
DeleteTrainedModelAction.Request::new, indexNameExpressionResolver);
this.trainedModelProvider = configProvider;
this.ingestService = ingestService;
this.auditor = Objects.requireNonNull(auditor);
}
@Override
protected String executor() {
return ThreadPool.Names.SAME;
}
@Override
protected AcknowledgedResponse read(StreamInput in) throws IOException {
return new AcknowledgedResponse(in);
}
@Override
protected void masterOperation(DeleteTrainedModelAction.Request request,
ClusterState state,
ActionListener<AcknowledgedResponse> listener) {
String id = request.getId();
IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE);
Set<String> referencedModels = getReferencedModelKeys(currentIngestMetadata);
if (referencedModels.contains(id)) {
listener.onFailure(new ElasticsearchStatusException("Cannot delete model [{}] as it is still referenced by ingest processors",
RestStatus.CONFLICT,
id));
return;
}
trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap(
r -> {
auditor.info(request.getId(), "trained model deleted");
listener.onResponse(new AcknowledgedResponse(true));
},
listener::onFailure
));
}
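// Parse every pipeline in the cluster state and collect the model ids referenced
// by inference processors; pipelines that fail to parse are logged and skipped.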
private Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
Set<String> allReferencedModelKeys = new HashSet<>();
if (ingestMetadata == null) {
return allReferencedModelKeys;
}
for(Map.Entry<String, PipelineConfiguration> entry : ingestMetadata.getPipelines().entrySet()) {
String pipelineId = entry.getKey();
Map<String, Object> config = entry.getValue().getConfigAsMap();
try {
Pipeline pipeline = Pipeline.create(pipelineId,
config,
ingestService.getProcessorFactories(),
ingestService.getScriptService());
pipeline.getProcessors().stream()
.filter(p -> p instanceof InferenceProcessor)
.map(p -> (InferenceProcessor) p)
.map(InferenceProcessor::getModelId)
.forEach(allReferencedModelKeys::add);
} catch (Exception ex) {
LOGGER.warn(new ParameterizedMessage("failed to load pipeline [{}]", pipelineId), ex);
}
}
return allReferencedModelKeys;
}
@Override
protected ClusterBlockException checkBlock(DeleteTrainedModelAction.Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
}
}
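For context, a minimal sketch of driving this action from an internal node client and handling the CONFLICT above; client and logger are assumed to be in scope and the model id is illustrative:
DeleteTrainedModelAction.Request deleteRequest = new DeleteTrainedModelAction.Request("my-model");
client.execute(DeleteTrainedModelAction.INSTANCE, deleteRequest, ActionListener.wrap(
acknowledged -> logger.info("[my-model] deleted"),
// a CONFLICT status here means an ingest pipeline still references the model
e -> logger.warn("[my-model] delete failed", e)
));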


@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import java.util.Collections;
import java.util.Set;
public class TransportGetTrainedModelsAction extends HandledTransportAction<Request, Response> {
private final TrainedModelProvider provider;
@Inject
public TransportGetTrainedModelsAction(TransportService transportService,
ActionFilters actionFilters,
TrainedModelProvider trainedModelProvider) {
super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new);
this.provider = trainedModelProvider;
}
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
Response.Builder responseBuilder = Response.builder();
ActionListener<Tuple<Long, Set<String>>> idExpansionListener = ActionListener.wrap(
totalAndIds -> {
responseBuilder.setTotalCount(totalAndIds.v1());
if (totalAndIds.v2().isEmpty()) {
listener.onResponse(responseBuilder.build());
return;
}
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
listener.onFailure(
ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)
);
return;
}
if (request.isIncludeModelDefinition()) {
provider.getTrainedModel(totalAndIds.v2().iterator().next(), true, ActionListener.wrap(
config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()),
listener::onFailure
));
} else {
provider.getTrainedModels(totalAndIds.v2(), request.isAllowNoResources(), ActionListener.wrap(
configs -> listener.onResponse(responseBuilder.setModels(configs).build()),
listener::onFailure
));
}
},
listener::onFailure
);
provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idExpansionListener);
}
}


@ -0,0 +1,249 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request,
GetTrainedModelsStatsAction.Response> {
private final Client client;
private final ClusterService clusterService;
private final IngestService ingestService;
private final TrainedModelProvider trainedModelProvider;
@Inject
public TransportGetTrainedModelsStatsAction(TransportService transportService,
ActionFilters actionFilters,
ClusterService clusterService,
IngestService ingestService,
TrainedModelProvider trainedModelProvider,
Client client) {
super(GetTrainedModelsStatsAction.NAME, transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
this.client = client;
this.clusterService = clusterService;
this.ingestService = ingestService;
this.trainedModelProvider = trainedModelProvider;
}
@Override
protected void doExecute(Task task,
GetTrainedModelsStatsAction.Request request,
ActionListener<GetTrainedModelsStatsAction.Response> listener) {
GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(
nodesStatsResponse -> {
Map<String, IngestStats> modelIdIngestStats = inferenceIngestStatsByPipelineId(nodesStatsResponse,
pipelineIdsByModelIds(clusterService.state(),
ingestService,
responseBuilder.getExpandedIds()));
listener.onResponse(responseBuilder.setIngestStatsByModelId(modelIdIngestStats).build());
},
listener::onFailure
);
ActionListener<Tuple<Long, Set<String>>> idsListener = ActionListener.wrap(
tuple -> {
responseBuilder.setExpandedIds(tuple.v2())
.setTotalModelCount(tuple.v1());
String[] ingestNodes = ingestNodes(clusterService.state());
NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().ingest(true);
executeAsyncWithOrigin(client, ML_ORIGIN, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
},
listener::onFailure
);
trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), idsListener);
}
static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse response,
Map<String, Set<String>> modelIdToPipelineId) {
Map<String, IngestStats> ingestStatsMap = new HashMap<>();
modelIdToPipelineId.forEach((modelId, pipelineIds) -> {
List<IngestStats> collectedStats = response.getNodes()
.stream()
.map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds))
.collect(Collectors.toList());
ingestStatsMap.put(modelId, mergeStats(collectedStats));
});
return ingestStatsMap;
}
static String[] ingestNodes(final ClusterState clusterState) {
String[] ingestNodes = new String[clusterState.nodes().getIngestNodes().size()];
Iterator<String> nodeIterator = clusterState.nodes().getIngestNodes().keysIt();
int i = 0;
while (nodeIterator.hasNext()) {
ingestNodes[i++] = nodeIterator.next();
}
return ingestNodes;
}
static Map<String, Set<String>> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set<String> modelIds) {
IngestMetadata ingestMetadata = state.metaData().custom(IngestMetadata.TYPE);
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
if (ingestMetadata == null) {
return pipelineIdsByModelIds;
}
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
try {
Pipeline pipeline = Pipeline.create(pipelineId,
pipelineConfiguration.getConfigAsMap(),
ingestService.getProcessorFactories(),
ingestService.getScriptService());
pipeline.getProcessors().forEach(processor -> {
if (processor instanceof InferenceProcessor) {
InferenceProcessor inferenceProcessor = (InferenceProcessor) processor;
if (modelIds.contains(inferenceProcessor.getModelId())) {
pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(),
m -> new LinkedHashSet<>()).add(pipelineId);
}
}
});
} catch (Exception ex) {
throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
}
});
return pipelineIdsByModelIds;
}
static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
IngestStats fullNodeStats = nodeStats.getIngestStats();
Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());
filteredProcessorStats.keySet().retainAll(pipelineIds);
List<IngestStats.PipelineStat> filteredPipelineStats = fullNodeStats.getPipelineStats()
.stream()
.filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId()))
.collect(Collectors.toList());
CounterMetric ingestCount = new CounterMetric();
CounterMetric ingestTimeInMillis = new CounterMetric();
CounterMetric ingestCurrent = new CounterMetric();
CounterMetric ingestFailedCount = new CounterMetric();
filteredPipelineStats.forEach(pipelineStat -> {
IngestStats.Stats stats = pipelineStat.getStats();
ingestCount.inc(stats.getIngestCount());
ingestTimeInMillis.inc(stats.getIngestTimeInMillis());
ingestCurrent.inc(stats.getIngestCurrent());
ingestFailedCount.inc(stats.getIngestFailedCount());
});
return new IngestStats(
new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count()),
filteredPipelineStats,
filteredProcessorStats);
}
private static IngestStats mergeStats(List<IngestStats> ingestStatsList) {
Map<String, IngestStatsAccumulator> pipelineStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
Map<String, Map<String, IngestStatsAccumulator>> processorStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
IngestStatsAccumulator totalStats = new IngestStatsAccumulator();
ingestStatsList.forEach(ingestStats -> {
ingestStats.getPipelineStats()
.forEach(pipelineStat ->
pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(),
p -> new IngestStatsAccumulator()).inc(pipelineStat.getStats()));
ingestStats.getProcessorStats()
.forEach((pipelineId, processorStat) -> {
Map<String, IngestStatsAccumulator> processorAcc = processorStatsAcc.computeIfAbsent(pipelineId,
k -> new LinkedHashMap<>());
processorStat.forEach(p ->
processorAcc.computeIfAbsent(p.getName(),
k -> new IngestStatsAccumulator(p.getType())).inc(p.getStats()));
});
totalStats.inc(ingestStats.getTotalStats());
});
List<IngestStats.PipelineStat> pipelineStatList = new ArrayList<>(pipelineStatsAcc.size());
pipelineStatsAcc.forEach((pipelineId, accumulator) ->
pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build())));
Map<String, List<IngestStats.ProcessorStat>> processorStatList = new LinkedHashMap<>(processorStatsAcc.size());
processorStatsAcc.forEach((pipelineId, accumulatorMap) -> {
List<IngestStats.ProcessorStat> processorStats = new ArrayList<>(accumulatorMap.size());
accumulatorMap.forEach((processorName, acc) ->
processorStats.add(new IngestStats.ProcessorStat(processorName, acc.type, acc.build())));
processorStatList.put(pipelineId, processorStats);
});
return new IngestStats(totalStats.build(), pipelineStatList, processorStatList);
}
private static class IngestStatsAccumulator {
CounterMetric ingestCount = new CounterMetric();
CounterMetric ingestTimeInMillis = new CounterMetric();
CounterMetric ingestCurrent = new CounterMetric();
CounterMetric ingestFailedCount = new CounterMetric();
String type;
IngestStatsAccumulator() {}
IngestStatsAccumulator(String type) {
this.type = type;
}
IngestStatsAccumulator inc(IngestStats.Stats s) {
ingestCount.inc(s.getIngestCount());
ingestTimeInMillis.inc(s.getIngestTimeInMillis());
ingestCurrent.inc(s.getIngestCurrent());
ingestFailedCount.inc(s.getIngestFailedCount());
return this;
}
IngestStats.Stats build() {
return new IngestStats.Stats(ingestCount.count(), ingestTimeInMillis.count(), ingestCurrent.count(), ingestFailedCount.count());
}
}
}
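To make the merge above concrete, a hedged sketch of combining two nodes' stats (the numbers are made up; IngestStatsAccumulator is private to this class, so the snippet only compiles inside it, e.g. in a test within the same file):
IngestStats.Stats nodeA = new IngestStats.Stats(10, 100, 0, 1);
IngestStats.Stats nodeB = new IngestStats.Stats(5, 50, 1, 0);
IngestStatsAccumulator acc = new IngestStatsAccumulator();
acc.inc(nodeA).inc(nodeB);
// build() yields ingestCount=15, ingestTimeInMillis=150, ingestCurrent=1, ingestFailedCount=1
IngestStats.Stats merged = acc.build();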


@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.ml.inference.loadingservice.Model;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;
public class TransportInferModelAction extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
private final ModelLoadingService modelLoadingService;
private final Client client;
private final XPackLicenseState licenseState;
@Inject
public TransportInferModelAction(TransportService transportService,
ActionFilters actionFilters,
ModelLoadingService modelLoadingService,
Client client,
XPackLicenseState licenseState) {
super(InferModelAction.NAME, transportService, actionFilters, InferModelAction.Request::new);
this.modelLoadingService = modelLoadingService;
this.client = client;
this.licenseState = licenseState;
}
@Override
protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
if (licenseState.isMachineLearningAllowed() == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
return;
}
ActionListener<Model> getModelListener = ActionListener.wrap(
model -> {
TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor =
new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME),
// run through all tasks
r -> true,
// Always fail immediately and return an error
ex -> true);
request.getObjectsToInfer().forEach(stringObjectMap ->
typedChainTaskExecutor.add(chainedTask ->
model.infer(stringObjectMap, request.getConfig(), chainedTask)));
typedChainTaskExecutor.execute(ActionListener.wrap(
inferenceResultsInterfaces ->
listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)),
listener::onFailure
));
},
listener::onFailure
);
this.modelLoadingService.getModel(request.getModelId(), getModelListener);
}
}
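A hedged sketch of invoking this action with one document's fields; the feature name is illustrative, RegressionConfig is used only as an example configuration, and client/logger are assumed to be in scope:
Map<String, Object> fields = new HashMap<>();
fields.put("feature_1", 0.5);
InferModelAction.Request inferRequest = new InferModelAction.Request("my-model", fields, new RegressionConfig());
client.execute(InferModelAction.INSTANCE, inferRequest, ActionListener.wrap(
response -> response.getInferenceResults().forEach(result -> logger.info("inference result [{}]", result)),
e -> logger.warn("[my-model] inference failed", e)
));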


@ -150,6 +150,8 @@ public class AnalyticsResultProcessor {
.setMetadata(Collections.singletonMap("analytics_config",
XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true)))
.setDefinition(definition)
.setEstimatedHeapMemory(definition.ramBytesUsed())
.setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations())
.setInput(new TrainedModelInput(fieldNames))
.build();
}


@ -7,14 +7,17 @@ package org.elasticsearch.xpack.ml.dataframe.process.results;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import java.io.IOException;
import java.util.Collections;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
public class AnalyticsResult implements ToXContentObject {
@ -67,7 +70,9 @@ public class AnalyticsResult implements ToXContentObject {
builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
}
if (inferenceModel != null) {
builder.field(INFERENCE_MODEL.getPreferredName(), inferenceModel);
builder.field(INFERENCE_MODEL.getPreferredName(),
inferenceModel,
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")));
}
builder.endObject();
return builder;


@ -0,0 +1,297 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.ingest;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.ingest.AbstractProcessor;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.LicenseStateListener;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
public class InferenceProcessor extends AbstractProcessor {
// How many total inference processors are allowed to be used in the cluster.
public static final Setting<Integer> MAX_INFERENCE_PROCESSORS = Setting.intSetting("xpack.ml.max_inference_processors",
50,
1,
Setting.Property.Dynamic,
Setting.Property.NodeScope);
public static final String TYPE = "inference";
public static final String MODEL_ID = "model_id";
public static final String INFERENCE_CONFIG = "inference_config";
public static final String TARGET_FIELD = "target_field";
public static final String FIELD_MAPPINGS = "field_mappings";
public static final String MODEL_INFO_FIELD = "model_info_field";
public static final String INCLUDE_MODEL_METADATA = "include_model_metadata";
private final Client client;
private final String modelId;
private final String targetField;
private final String modelInfoField;
private final InferenceConfig inferenceConfig;
private final Map<String, String> fieldMapping;
private final boolean includeModelMetadata;
public InferenceProcessor(Client client,
String tag,
String targetField,
String modelId,
InferenceConfig inferenceConfig,
Map<String, String> fieldMapping,
String modelInfoField,
boolean includeModelMetadata) {
super(tag);
this.client = ExceptionsHelper.requireNonNull(client, "client");
this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD);
this.modelInfoField = ExceptionsHelper.requireNonNull(modelInfoField, MODEL_INFO_FIELD);
this.includeModelMetadata = includeModelMetadata;
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
this.fieldMapping = ExceptionsHelper.requireNonNull(fieldMapping, FIELD_MAPPINGS);
}
public String getModelId() {
return modelId;
}
@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
executeAsyncWithOrigin(client,
ML_ORIGIN,
InferModelAction.INSTANCE,
this.buildRequest(ingestDocument),
ActionListener.wrap(
r -> {
try {
mutateDocument(r, ingestDocument);
handler.accept(ingestDocument, null);
} catch (ElasticsearchException ex) {
handler.accept(ingestDocument, ex);
}
},
e -> handler.accept(ingestDocument, e)
));
}
InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
Map<String, Object> fields = new HashMap<>(ingestDocument.getSourceAndMetadata());
if (fieldMapping != null) {
fieldMapping.forEach((src, dest) -> {
Object srcValue = fields.remove(src);
if (srcValue != null) {
fields.put(dest, srcValue);
}
});
}
return new InferModelAction.Request(modelId, fields, inferenceConfig);
}
void mutateDocument(InferModelAction.Response response, IngestDocument ingestDocument) {
if (response.getInferenceResults().isEmpty()) {
throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR);
}
response.getInferenceResults().get(0).writeResult(ingestDocument, this.targetField);
if (includeModelMetadata) {
ingestDocument.setFieldValue(modelInfoField + "." + MODEL_ID, modelId);
}
}
@Override
public IngestDocument execute(IngestDocument ingestDocument) {
throw new UnsupportedOperationException("should never be called");
}
@Override
public String getType() {
return TYPE;
}
public static final class Factory implements Processor.Factory, Consumer<ClusterState>, LicenseStateListener {
private static final Logger logger = LogManager.getLogger(Factory.class);
private final Client client;
private final IngestService ingestService;
private final XPackLicenseState licenseState;
private volatile int currentInferenceProcessors;
private volatile int maxIngestProcessors;
private volatile Version minNodeVersion = Version.CURRENT;
private volatile boolean inferenceAllowed;
public Factory(Client client,
ClusterService clusterService,
Settings settings,
IngestService ingestService,
XPackLicenseState licenseState) {
this.client = client;
this.maxIngestProcessors = MAX_INFERENCE_PROCESSORS.get(settings);
this.ingestService = ingestService;
this.licenseState = licenseState;
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_INFERENCE_PROCESSORS, this::setMaxIngestProcessors);
}
@Override
public void accept(ClusterState state) {
minNodeVersion = state.nodes().getMinNodeVersion();
MetaData metaData = state.getMetaData();
if (metaData == null) {
currentInferenceProcessors = 0;
return;
}
IngestMetadata ingestMetadata = metaData.custom(IngestMetadata.TYPE);
if (ingestMetadata == null) {
currentInferenceProcessors = 0;
return;
}
int count = 0;
for (PipelineConfiguration configuration : ingestMetadata.getPipelines().values()) {
try {
Pipeline pipeline = Pipeline.create(configuration.getId(),
configuration.getConfigAsMap(),
ingestService.getProcessorFactories(),
ingestService.getScriptService());
count += pipeline.getProcessors().stream().filter(processor -> processor instanceof InferenceProcessor).count();
} catch (Exception ex) {
logger.warn(new ParameterizedMessage("failure parsing pipeline config [{}]", configuration.getId()), ex);
}
}
currentInferenceProcessors = count;
}
// Used for testing
int numInferenceProcessors() {
return currentInferenceProcessors;
}
@Override
public InferenceProcessor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config)
throws Exception {
if (inferenceAllowed == false) {
throw LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING);
}
if (this.maxIngestProcessors <= currentInferenceProcessors) {
throw new ElasticsearchStatusException("Max number of inference processors reached, total inference processors [{}]. " +
"Adjust the setting [{}]: [{}] if a greater number is desired.",
RestStatus.CONFLICT,
currentInferenceProcessors,
MAX_INFERENCE_PROCESSORS.getKey(),
maxIngestProcessors);
}
boolean includeModelMetadata = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, INCLUDE_MODEL_METADATA, true);
String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID);
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD);
Map<String, String> fieldMapping = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS);
InferenceConfig inferenceConfig = inferenceConfigFromMap(ConfigurationUtils.readMap(TYPE, tag, config, INFERENCE_CONFIG));
String modelInfoField = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_INFO_FIELD, "ml");
// If multiple inference processors are in the same pipeline, it is wise to tag them
// The tag will keep metadata entries from stepping on each other
if (tag != null) {
modelInfoField += "." + tag;
}
return new InferenceProcessor(client,
tag,
targetField,
modelId,
inferenceConfig,
fieldMapping,
modelInfoField,
includeModelMetadata);
}
// Package private for testing
void setMaxIngestProcessors(int maxIngestProcessors) {
logger.trace("updating setting maxIngestProcessors from [{}] to [{}]", this.maxIngestProcessors, maxIngestProcessors);
this.maxIngestProcessors = maxIngestProcessors;
}
InferenceConfig inferenceConfigFromMap(Map<String, Object> inferenceConfig) throws IOException {
ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
if (inferenceConfig.size() != 1) {
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
INFERENCE_CONFIG);
}
Object value = inferenceConfig.values().iterator().next();
if ((value instanceof Map<?, ?>) == false) {
throw ExceptionsHelper.badRequestException("{} must be an object with one inference type mapped to an object.",
INFERENCE_CONFIG);
}
@SuppressWarnings("unchecked")
Map<String, Object> valueMap = (Map<String, Object>)value;
if (inferenceConfig.containsKey(ClassificationConfig.NAME)) {
checkSupportedVersion(new ClassificationConfig(0));
return ClassificationConfig.fromMap(valueMap);
} else if (inferenceConfig.containsKey(RegressionConfig.NAME)) {
checkSupportedVersion(new RegressionConfig());
return RegressionConfig.fromMap(valueMap);
} else {
throw ExceptionsHelper.badRequestException("unrecognized inference configuration type {}. Supported types {}",
inferenceConfig.keySet(),
Arrays.asList(ClassificationConfig.NAME, RegressionConfig.NAME));
}
}
void checkSupportedVersion(InferenceConfig config) {
if (config.getMinimalSupportedVersion().after(minNodeVersion)) {
throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION,
config.getName(),
config.getMinimalSupportedVersion(),
minNodeVersion));
}
}
@Override
public void licenseStateChanged() {
this.inferenceAllowed = licenseState.isMachineLearningAllowed();
}
}
}
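For reference, a hedged sketch of the configuration map that Factory.create consumes. Key names come from the constants above; "regression" assumes RegressionConfig.NAME resolves to that string, and factory/processorFactories are assumed to be in scope:
Map<String, Object> config = new HashMap<>();
config.put(InferenceProcessor.MODEL_ID, "my-model");
config.put(InferenceProcessor.TARGET_FIELD, "my_inference_results");
config.put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source_field", "model_feature"));
// inference_config must be an object with exactly one inference type mapped to an object
config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", new HashMap<>()));
InferenceProcessor processor = factory.create(processorFactories, "my-tag", config);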


@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import java.util.Map;
public class LocalModel implements Model {
private final TrainedModelDefinition trainedModelDefinition;
private final String modelId;
public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) {
this.trainedModelDefinition = trainedModelDefinition;
this.modelId = modelId;
}
long ramBytesUsed() {
return trainedModelDefinition.ramBytesUsed();
}
@Override
public String getModelId() {
return modelId;
}
@Override
public String getResultsType() {
switch (trainedModelDefinition.getTrainedModel().targetType()) {
case CLASSIFICATION:
return ClassificationInferenceResults.NAME;
case REGRESSION:
return RegressionInferenceResults.NAME;
default:
throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]",
modelId,
trainedModelDefinition.getTrainedModel().targetType());
}
}
@Override
public void infer(Map<String, Object> fields, InferenceConfig config, ActionListener<InferenceResults> listener) {
try {
listener.onResponse(trainedModelDefinition.infer(fields, config));
} catch (Exception e) {
listener.onFailure(e);
}
}
}
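A minimal usage sketch, assuming a parsed TrainedModelDefinition is in hand, the model's target type is regression, and logger is in scope:
LocalModel model = new LocalModel("my-model", trainedModelDefinition);
Map<String, Object> doc = Collections.singletonMap("feature_1", 0.5);
model.infer(doc, new RegressionConfig(), ActionListener.wrap(
result -> logger.info("[{}] produced a [{}] result", model.getModelId(), model.getResultsType()),
e -> logger.warn("[my-model] inference failed", e)
));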


@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import java.util.Map;
public interface Model {
String getResultsType();
void infer(Map<String, Object> fields, InferenceConfig inferenceConfig, ActionListener<InferenceResults> listener);
String getModelId();
}


@ -0,0 +1,370 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.TimeUnit;
/**
* This is a thread safe model loading service.
*
* It caches in memory the local models that are referenced by ingest processors (as long as the service is instantiated on an ingest node).
*
* If more than one processor references the same model, that model will only be cached once.
*/
public class ModelLoadingService implements ClusterStateListener {
/**
* The maximum size of the local model cache in the loading service
*
* Once the limit is reached, LRU models are evicted in favor of new models
*/
public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE =
Setting.byteSizeSetting("xpack.ml.inference_model.cache_size",
new ByteSizeValue(1, ByteSizeUnit.GB),
Setting.Property.NodeScope);
/**
* How long should a model stay in the cache since its last access
*
* If nothing references a model via getModel for this configured timeValue, it will be evicted.
*
* Specifically, in the ingest scenario, a processor will call getModel whenever it needs to run inference. So, if a processor is not
* executed for an extended period of time, the model will be evicted and will have to be loaded again when getModel is called.
*
*/
public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL =
Setting.timeSetting("xpack.ml.inference_model.time_to_live",
new TimeValue(5, TimeUnit.MINUTES),
new TimeValue(1, TimeUnit.MILLISECONDS),
Setting.Property.NodeScope);
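// As a hedged illustration, both are node settings and could be tuned in elasticsearch.yml
// (the values below are arbitrary):
//   xpack.ml.inference_model.cache_size: 512mb
//   xpack.ml.inference_model.time_to_live: 10m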
private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
private final Cache<String, LocalModel> localModelCache;
private final Set<String> referencedModels = new HashSet<>();
private final Map<String, Queue<ActionListener<Model>>> loadingListeners = new HashMap<>();
private final TrainedModelProvider provider;
private final Set<String> shouldNotAudit;
private final ThreadPool threadPool;
private final InferenceAuditor auditor;
private final ByteSizeValue maxCacheSize;
public ModelLoadingService(TrainedModelProvider trainedModelProvider,
InferenceAuditor auditor,
ThreadPool threadPool,
ClusterService clusterService,
Settings settings) {
this.provider = trainedModelProvider;
this.threadPool = threadPool;
this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
this.auditor = auditor;
this.shouldNotAudit = new HashSet<>();
this.localModelCache = CacheBuilder.<String, LocalModel>builder()
.setMaximumWeight(this.maxCacheSize.getBytes())
.weigher((id, localModel) -> localModel.ramBytesUsed())
.removalListener(this::cacheEvictionListener)
.setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
.build();
clusterService.addListener(this);
}
/**
* Gets the model referenced by `modelId` and responds to the listener.
*
* This method first checks the local LRU cache for the model. If it is present, it is returned from cache.
*
* If it is not present, one of the following occurs:
*
* - If the model is referenced by a pipeline and is currently being loaded, the `modelActionListener`
* is added to the list of listeners to be alerted when the model is fully loaded.
* - If the model is referenced by a pipeline and is currently NOT being loaded, a new load attempt is made and the
* resulting model is cached for future use
* - If the model is NOT referenced by a pipeline, the model is simply loaded from the index and given to the listener.
* It is not cached.
*
* @param modelId the model to get
* @param modelActionListener the listener to alert when the model has been retrieved.
*/
public void getModel(String modelId, ActionListener<Model> modelActionListener) {
LocalModel cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
logger.trace("[{}] loaded from cache", modelId);
return;
}
if (loadModelIfNecessary(modelId, modelActionListener) == false) {
// If the model is not loaded and we did not kick off a new loading attempt, we may be getting called
// by a simulated pipeline
logger.trace("[{}] not actively loading, eager loading without cache", modelId);
provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig ->
modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())),
modelActionListener::onFailure
));
} else {
logger.trace("[{}] is loading or loaded, added new listener to queue", modelId);
}
}
/**
* Returns true if the model is loaded and the listener has been given the cached model
* Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded
* Returns false if the model is neither loaded nor actively being loaded
*/
private boolean loadModelIfNecessary(String modelId, ActionListener<Model> modelActionListener) {
synchronized (loadingListeners) {
Model cachedModel = localModelCache.get(modelId);
if (cachedModel != null) {
modelActionListener.onResponse(cachedModel);
return true;
}
// It is referenced by a pipeline, but the cache does not contain it
if (referencedModels.contains(modelId)) {
// If the model is referenced by a pipeline but not present in the cache,
// the previous load attempt failed or the model has been evicted.
// Attempt to load and cache the model again.
if (loadingListeners.computeIfPresent(
modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) {
logger.trace("[{}] attempting to load and cache", modelId);
loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
loadModel(modelId);
}
return true;
}
// if the cachedModel entry is null, but there are listeners present, that means it is being loaded
return loadingListeners.computeIfPresent(modelId,
(storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null;
} // synchronized (loadingListeners)
}
private void loadModel(String modelId) {
provider.getTrainedModel(modelId, true, ActionListener.wrap(
trainedModelConfig -> {
logger.debug("[{}] successfully loaded model", modelId);
handleLoadSuccess(modelId, trainedModelConfig);
},
failure -> {
logger.warn(new ParameterizedMessage("[{}] failed to load model", modelId), failure);
handleLoadFailure(modelId, failure);
}
));
}
private void handleLoadSuccess(String modelId, TrainedModelConfig trainedModelConfig) {
Queue<ActionListener<Model>> listeners;
LocalModel loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition());
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
// If there is no loadingListener that means the loading was canceled and the listener was already notified as such
// Consequently, we should not store the retrieved model
if (listeners == null) {
return;
}
localModelCache.put(modelId, loadedModel);
shouldNotAudit.remove(modelId);
} // synchronized (loadingListeners)
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onResponse(loadedModel);
}
}
private void handleLoadFailure(String modelId, Exception failure) {
Queue<ActionListener<Model>> listeners;
synchronized (loadingListeners) {
listeners = loadingListeners.remove(modelId);
if (listeners == null) {
return;
}
} // synchronized (loadingListeners)
// If we failed to load and there were listeners present, that means that this model is referenced by a processor
// Alert the listeners to the failure
for (ActionListener<Model> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
listener.onFailure(failure);
}
}
private void cacheEvictionListener(RemovalNotification<String, LocalModel> notification) {
if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
String msg = new ParameterizedMessage(
"model cache entry evicted." +
"current cache [{}] current max [{}] model size [{}]. " +
"If this is undesired, consider updating setting [{}] or [{}].",
new ByteSizeValue(localModelCache.weight()).getStringRep(),
maxCacheSize.getStringRep(),
new ByteSizeValue(notification.getValue().ramBytesUsed()).getStringRep(),
INFERENCE_MODEL_CACHE_SIZE.getKey(),
INFERENCE_MODEL_CACHE_TTL.getKey()).getFormattedMessage();
auditIfNecessary(notification.getKey(), msg);
}
}
@Override
public void clusterChanged(ClusterChangedEvent event) {
// If ingest data has not changed or if the current node is not an ingest node, don't bother caching models
if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE) == false ||
event.state().nodes().getLocalNode().isIngestNode() == false) {
return;
}
ClusterState state = event.state();
IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE);
Set<String> allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata);
if (allReferencedModelKeys.equals(referencedModels)) {
return;
}
// Listeners that are still waiting for a model whose load is now being canceled
List<Tuple<String, List<ActionListener<Model>>>> drainWithFailure = new ArrayList<>();
Set<String> referencedModelsBeforeClusterState = null;
Set<String> loadingModelBeforeClusterState = null;
Set<String> removedModels = null;
synchronized (loadingListeners) {
referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
if (logger.isTraceEnabled()) {
loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
}
// If models are still loading but are no longer referenced,
// remove them from loadingListeners and alert their listeners
for (String modelId : loadingListeners.keySet()) {
if (allReferencedModelKeys.contains(modelId) == false) {
drainWithFailure.add(Tuple.tuple(modelId, new ArrayList<>(loadingListeners.remove(modelId))));
}
}
removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
// Remove all cached models that are not referenced by any processors
removedModels.forEach(localModelCache::invalidate);
// Remove the models that are no longer referenced
referencedModels.removeAll(removedModels);
shouldNotAudit.removeAll(removedModels);
// Remove the models that are already referenced (the intersection of allReferencedModelKeys and referencedModels) so that only newly referenced models remain
allReferencedModelKeys.removeAll(referencedModels);
referencedModels.addAll(allReferencedModelKeys);
// Populate loadingListeners keys so we know which models are currently loading
for (String modelId : allReferencedModelKeys) {
loadingListeners.put(modelId, new ArrayDeque<>());
}
} // synchronized (loadingListeners)
if (logger.isTraceEnabled()) {
if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) {
logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState,
loadingListeners.keySet());
}
if (referencedModels.equals(referencedModelsBeforeClusterState) == false) {
logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState,
referencedModels);
}
}
for (Tuple<String, List<ActionListener<Model>>> modelAndListeners : drainWithFailure) {
final String msg = new ParameterizedMessage(
"Cancelling load of model [{}] as it is no longer referenced by a pipeline",
modelAndListeners.v1()).getFormat();
for (ActionListener<Model> listener : modelAndListeners.v2()) {
listener.onFailure(new ElasticsearchException(msg));
}
}
removedModels.forEach(this::auditUnreferencedModel);
loadModels(allReferencedModelKeys);
}
private void auditIfNecessary(String modelId, String msg) {
if (shouldNotAudit.contains(modelId)) {
logger.trace("[{}] {}", modelId, msg);
return;
}
auditor.warning(modelId, msg);
shouldNotAudit.add(modelId);
logger.warn("[{}] {}", modelId, msg);
}
private void loadModels(Set<String> modelIds) {
if (modelIds.isEmpty()) {
return;
}
// Execute this on a utility thread so that the resulting callbacks do not tie up the cluster state listener thread pool
threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> {
for (String modelId : modelIds) {
auditNewReferencedModel(modelId);
this.loadModel(modelId);
}
});
}
private void auditNewReferencedModel(String modelId) {
auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache");
}
private void auditUnreferencedModel(String modelId) {
auditor.info(modelId, "no longer referenced by any processors");
}
private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
queue.add(object);
return queue;
}
private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
Set<String> allReferencedModelKeys = new HashSet<>();
if (ingestMetadata == null) {
return allReferencedModelKeys;
}
ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
if (processors instanceof List<?>) {
for (Object processor : (List<?>) processors) {
if (processor instanceof Map<?, ?>) {
Object processorConfig = ((Map<?, ?>)processor).get(InferenceProcessor.TYPE);
if (processorConfig instanceof Map<?, ?>) {
Object modelId = ((Map<?, ?>)processorConfig).get(InferenceProcessor.MODEL_ID);
if (modelId != null) {
assert modelId instanceof String;
allReferencedModelKeys.add(modelId.toString());
}
}
}
}
}
});
return allReferencedModelKeys;
}
}
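A hedged sketch of how a caller, such as the transport action backing the inference processor, consumes this service; fields, inferenceConfig, inferenceListener, and logger are assumed to be in scope:
modelLoadingService.getModel("my-model", ActionListener.wrap(
model -> model.infer(fields, inferenceConfig, inferenceListener),
e -> logger.warn("[my-model] failed to load model", e)
));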


@ -24,6 +24,7 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappi
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT;
import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE;
@ -103,6 +104,12 @@ public final class InferenceInternalIndex {
.endObject()
.startObject(TrainedModelConfig.METADATA.getPreferredName())
.field(ENABLED, false)
.endObject()
.startObject(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName())
.field(TYPE, LONG)
.endObject()
.startObject(TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName())
.field(TYPE, LONG)
.endObject();
}
}


@ -20,10 +20,19 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.regex.Regex;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
@ -32,12 +41,21 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.reindex.DeleteByQueryAction;
import org.elasticsearch.index.reindex.DeleteByQueryRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
@ -47,10 +65,17 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FAILED_TO_DESERIALIZE;
public class TrainedModelProvider {
@ -189,6 +214,166 @@ public class TrainedModelProvider {
multiSearchResponseActionListener);
}
/**
* Gets all the provided trained model config objects
*
* NOTE:
* This does no expansion on the ids.
* It assumes that there are fewer than 10k.
*/
public void getTrainedModels(Set<String> modelIds, boolean allowNoResources, final ActionListener<List<TrainedModelConfig>> listener) {
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0])));
SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
.addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC)
.addSort("_index", SortOrder.DESC)
.setQuery(queryBuilder)
.request();
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
searchResponse -> {
Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
for (SearchHit searchHit : searchResponse.getHits().getHits()) {
try {
if (observedIds.contains(searchHit.getId()) == false) {
configs.add(
parseInferenceDocLenientlyFromSource(searchHit.getSourceRef(), searchHit.getId()).build()
);
observedIds.add(searchHit.getId());
}
} catch (IOException ex) {
listener.onFailure(
ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ex, searchHit.getId()));
return;
}
}
// We previously expanded the IDs.
// If the config has gone missing between then and now we should throw if allowNoResources is false
// Otherwise, treat it as if it was never expanded to begin with.
Set<String> missingConfigs = Sets.difference(modelIds, observedIds);
if (missingConfigs.isEmpty() == false && allowNoResources == false) {
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
return;
}
listener.onResponse(configs);
},
listener::onFailure
);
executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, configSearchHandler);
}
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
request.indices(InferenceIndexConstants.INDEX_PATTERN);
QueryBuilder query = QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId);
request.setQuery(query);
request.setRefresh(true);
executeAsyncWithOrigin(client, ML_ORIGIN, DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(deleteResponse -> {
if (deleteResponse.getDeleted() == 0) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return;
}
listener.onResponse(true);
}, e -> {
if (e.getClass() == IndexNotFoundException.class) {
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
} else {
listener.onFailure(e);
}
}));
}
public void expandIds(String idExpression,
boolean allowNoResources,
@Nullable PageParams pageParams,
ActionListener<Tuple<Long, Set<String>>> idsListener) {
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
// If there are no resources, there might be no mapping for the id field.
// This makes sure we don't get an error if that happens.
.unmappedType("long"))
.query(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
if (pageParams != null) {
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
}
sourceBuilder.trackTotalHits(true)
// we only care about the item ids
.fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
IndicesOptions indicesOptions = SearchRequest.DEFAULT_INDICES_OPTIONS;
SearchRequest searchRequest = new SearchRequest(InferenceIndexConstants.INDEX_PATTERN)
.indicesOptions(IndicesOptions.fromOptions(true,
indicesOptions.allowNoIndices(),
indicesOptions.expandWildcardsOpen(),
indicesOptions.expandWildcardsClosed(),
indicesOptions))
.source(sourceBuilder);
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(
response -> {
Set<String> foundResourceIds = new LinkedHashSet<>();
long totalHitCount = response.getHits().getTotalHits().value;
for (SearchHit hit : response.getHits().getHits()) {
Map<String, Object> docSource = hit.getSourceAsMap();
if (docSource == null) {
continue;
}
Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
if (idValue instanceof String) {
foundResourceIds.add(idValue.toString());
}
}
ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
requiredMatches.filterMatchedIds(foundResourceIds);
if (requiredMatches.hasUnmatchedIds()) {
idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
} else {
idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
}
},
idsListener::onFailure
),
client::search);
}
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
if (Strings.isAllOrWildcard(tokens)) {
return boolQuery;
}
// If the resourceId is not _all or *, we should see if it is a comma delimited string with wild-cards
// e.g. id1,id2*,id3
BoolQueryBuilder shouldQueries = new BoolQueryBuilder();
List<String> terms = new ArrayList<>();
for (String token : tokens) {
if (Regex.isSimpleMatchPattern(token)) {
shouldQueries.should(QueryBuilders.wildcardQuery(resourceIdField, token));
} else {
terms.add(token);
}
}
if (terms.isEmpty() == false) {
shouldQueries.should(QueryBuilders.termsQuery(resourceIdField, terms));
}
if (shouldQueries.should().isEmpty() == false) {
boolQuery.filter(shouldQueries);
}
return boolQuery;
}
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
String resourceId,
@ -202,23 +387,23 @@ public class TrainedModelProvider {
return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId);
}
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception {
private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
return TrainedModelConfig.fromXContent(parser, true);
        } catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e);
throw e;
}
}
    private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws IOException {
try (InputStream stream = source.streamInput();
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) {
return TrainedModelDefinition.fromXContent(parser, true).build();
        } catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e);
throw e;
}

View File

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.notifications;
import org.elasticsearch.client.Client;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
import org.elasticsearch.xpack.core.ml.notifications.AuditorField;
import org.elasticsearch.xpack.core.ml.notifications.InferenceAuditMessage;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
public class InferenceAuditor extends AbstractAuditor<InferenceAuditMessage> {
public InferenceAuditor(Client client, String nodeName) {
super(client, nodeName, AuditorField.NOTIFICATIONS_INDEX, ML_ORIGIN, InferenceAuditMessage::new);
}
}
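A usage sketch for the auditor (assuming the info/warning helpers inherited from AbstractAuditor keyed by resource id; the node name, model id, and messages are made up for illustration):

    InferenceAuditor auditor = new InferenceAuditor(client, "node-1");
    auditor.info("my-model", "loaded model and cached it on this node");
    auditor.warning("my-model", "model cache under memory pressure");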

View File

@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import static org.elasticsearch.rest.RestRequest.Method.DELETE;
public class RestDeleteTrainedModelAction extends BaseRestHandler {
public RestDeleteTrainedModelAction(RestController controller) {
controller.registerHandler(
DELETE, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
}
@Override
public String getName() {
return "ml_delete_trained_models_action";
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
DeleteTrainedModelAction.Request request = new DeleteTrainedModelAction.Request(modelId);
return channel -> client.execute(DeleteTrainedModelAction.INSTANCE, request, new RestToXContentListener<>(channel));
}
}
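A client-side sketch of the endpoint this handler serves, using the low-level REST client (the model id is hypothetical; assumes a configured org.elasticsearch.client.RestClient named restClient):

    // DELETE _ml/inference/{model_id}
    Request delete = new Request("DELETE", "/_ml/inference/my-model");
    Response response = restClient.performRequest(delete);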

View File

@ -0,0 +1,56 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;
public class RestGetTrainedModelsAction extends BaseRestHandler {
public RestGetTrainedModelsAction(RestController controller) {
controller.registerHandler(
GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}", this);
controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference", this);
}
@Override
public String getName() {
return "ml_get_trained_models_action";
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
if (Strings.isNullOrEmpty(modelId)) {
modelId = MetaData.ALL;
}
boolean includeModelDefinition = restRequest.paramAsBoolean(
GetTrainedModelsAction.Request.INCLUDE_MODEL_DEFINITION.getPreferredName(),
false
);
GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId, includeModelDefinition);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
}
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
return channel -> client.execute(GetTrainedModelsAction.INSTANCE, request, new RestToXContentListener<>(channel));
}
}
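Correspondingly, the GET handler supports paging plus the include_model_definition flag; a hedged client-side sketch (same assumed restClient, hypothetical model id pattern):

    // GET _ml/inference/{model_id} with optional paging and definition flag
    Request get = new Request("GET", "/_ml/inference/my-model*");
    get.addParameter("include_model_definition", "true");
    get.addParameter("from", "0");
    get.addParameter("size", "10");
    get.addParameter("allow_no_match", "true");
    Response response = restClient.performRequest(get);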

View File

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.rest.inference;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.ml.MachineLearning;
import java.io.IOException;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import static org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Request.ALLOW_NO_MATCH;
public class RestGetTrainedModelsStatsAction extends BaseRestHandler {
public RestGetTrainedModelsStatsAction(RestController controller) {
controller.registerHandler(
GET, MachineLearning.BASE_PATH + "inference/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/_stats", this);
controller.registerHandler(GET, MachineLearning.BASE_PATH + "inference/_stats", this);
}
@Override
public String getName() {
return "ml_get_trained_models_stats_action";
}
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
if (Strings.isNullOrEmpty(modelId)) {
modelId = MetaData.ALL;
}
GetTrainedModelsStatsAction.Request request = new GetTrainedModelsStatsAction.Request(modelId);
if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) {
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM),
restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE)));
}
request.setAllowNoResources(restRequest.paramAsBoolean(ALLOW_NO_MATCH.getPreferredName(), request.isAllowNoResources()));
return channel -> client.execute(GetTrainedModelsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel));
}
}
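And the matching _stats sketch (again illustrative; assumes the same restClient):

    // GET _ml/inference/_stats or GET _ml/inference/{model_id}/_stats
    Request stats = new Request("GET", "/_ml/inference/my-model/_stats");
    stats.addParameter("allow_no_match", "true");
    Response response = restClient.performRequest(stats);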

View File

@ -6,13 +6,21 @@
package org.elasticsearch.license;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ingest.PutPipelineAction;
import org.elasticsearch.action.ingest.PutPipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineAction;
import org.elasticsearch.action.ingest.SimulatePipelineRequest;
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.rest.RestStatus;
@ -24,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.PutJobAction;
@ -31,17 +40,24 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.core.ml.client.MachineLearningClient;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.LocalStateMachineLearning;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.junit.Before;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
@ -529,6 +545,216 @@ public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
}
}
public void testMachineLearningCreateInferenceProcessorRestricted() {
String modelId = "modelprocessorlicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
String pipeline = "{" +
" \"processors\": [\n" +
" {\n" +
" \"inference\": {\n" +
" \"target_field\": \"regression_value\",\n" +
" \"model_id\": \"modelprocessorlicensetest\",\n" +
" \"inference_config\": {\"regression\": {}},\n" +
" \"field_mappings\": {\n" +
" \"col1\": \"col1\",\n" +
" \"col2\": \"col2\",\n" +
" \"col3\": \"col3\",\n" +
" \"col4\": \"col4\"\n" +
" }\n" +
" }\n" +
" }]}\n";
// test that the license-restricted APIs work while the license is valid
PlainActionFuture<AcknowledgedResponse> putPipelineListener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListener);
AcknowledgedResponse putPipelineResponse = putPipelineListener.actionGet();
assertTrue(putPipelineResponse.isAcknowledged());
String simulateSource = "{\n" +
" \"pipeline\": \n" +
pipeline +
" ,\n" +
" \"docs\": [\n" +
" {\"_source\": {\n" +
" \"col1\": \"female\",\n" +
" \"col2\": \"M\",\n" +
" \"col3\": \"none\",\n" +
" \"col4\": 10\n" +
" }}]\n" +
"}";
PlainActionFuture<SimulatePipelineResponse> simulatePipelineListener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
simulatePipelineListener);
assertThat(simulatePipelineListener.actionGet().getResults(), is(not(empty())));
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// creating a new pipeline should fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<AcknowledgedResponse> listener = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline_failure",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
// Simulating the pipeline should fail
e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<SimulatePipelineResponse> listener = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that the license-restricted APIs work again under the new valid license
PlainActionFuture<AcknowledgedResponse> putPipelineListenerNewLicense = PlainActionFuture.newFuture();
client().execute(PutPipelineAction.INSTANCE,
new PutPipelineRequest("test_infer_license_pipeline",
new BytesArray(pipeline.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON),
putPipelineListenerNewLicense);
AcknowledgedResponse putPipelineResponseNewLicense = putPipelineListenerNewLicense.actionGet();
assertTrue(putPipelineResponseNewLicense.isAcknowledged());
PlainActionFuture<SimulatePipelineResponse> simulatePipelineListenerNewLicense = PlainActionFuture.newFuture();
client().execute(SimulatePipelineAction.INSTANCE,
new SimulatePipelineRequest(new BytesArray(simulateSource.getBytes(StandardCharsets.UTF_8)), XContentType.JSON),
simulatePipelineListenerNewLicense);
assertThat(simulatePipelineListenerNewLicense.actionGet().getResults(), is(not(empty())));
}
public void testMachineLearningInferModelRestricted() throws Exception {
String modelId = "modelinfermodellicensetest";
assertMLAllowed(true);
putInferenceModel(modelId);
PlainActionFuture<InferModelAction.Response> inferModelSuccess = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), inferModelSuccess);
assertThat(inferModelSuccess.actionGet().getInferenceResults(), is(not(empty())));
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// inferring against a model should now fail
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
listener.actionGet();
});
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackField.MACHINE_LEARNING));
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
PlainActionFuture<InferModelAction.Response> listener = PlainActionFuture.newFuture();
client().execute(InferModelAction.INSTANCE, new InferModelAction.Request(
modelId,
Collections.singletonList(Collections.emptyMap()),
new RegressionConfig()
), listener);
assertThat(listener.actionGet().getInferenceResults(), is(not(empty())));
}
private void putInferenceModel(String modelId) {
String config = "" +
"{\n" +
" \"model_id\": \"" + modelId + "\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" +
" \"created_by\": \"benwtrent\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0,\n" +
" \"estimated_operations\": 0,\n" +
" \"created_time\": 0\n" +
"}";
String definition = "" +
"{" +
" \"trained_model\": {\n" +
" \"tree\": {\n" +
" \"feature_names\": [\n" +
" \"col1_male\",\n" +
" \"col1_female\",\n" +
" \"col2_encoded\",\n" +
" \"col3_encoded\",\n" +
" \"col4\"\n" +
" ],\n" +
" \"tree_structure\": [\n" +
" {\n" +
" \"node_index\": 0,\n" +
" \"split_feature\": 0,\n" +
" \"split_gain\": 12.0,\n" +
" \"threshold\": 10.0,\n" +
" \"decision_type\": \"lte\",\n" +
" \"default_left\": true,\n" +
" \"left_child\": 1,\n" +
" \"right_child\": 2\n" +
" },\n" +
" {\n" +
" \"node_index\": 1,\n" +
" \"leaf_value\": 1\n" +
" },\n" +
" {\n" +
" \"node_index\": 2,\n" +
" \"leaf_value\": 2\n" +
" }\n" +
" ],\n" +
" \"target_type\": \"regression\"\n" +
" }\n" +
" }," +
" \"model_id\": \"" + modelId + "\"\n" +
"}";
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(modelId)
.setSource(config, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
assertThat(client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
.setId(TrainedModelDefinition.docId(modelId))
.setSource(definition, XContentType.JSON)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get().status(), equalTo(RestStatus.CREATED));
}
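Worth noting for readers of this fixture: the toy tree splits on feature 0 (col1_male, a 0/1 one-hot value) against a threshold of 10.0 with default_left true, so every document, including ones with missing features, takes the left branch and the model deterministically predicts leaf_value 1. The license tests only assert that inference results are non-empty, so the constant prediction is fine.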
private static OperationMode randomInvalidLicenseType() {
return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC);
}

View File

@ -8,8 +8,16 @@ package org.elasticsearch.xpack.ml;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
@ -20,13 +28,18 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.XPackFeatureSet;
import org.elasticsearch.xpack.core.XPackFeatureSet.Usage;
import org.elasticsearch.xpack.core.XPackField;
@ -40,6 +53,7 @@ import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
@ -49,10 +63,12 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeSta
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
import org.elasticsearch.xpack.core.ml.stats.ForecastStatsTests;
import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;
import org.junit.Before;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
@ -61,6 +77,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
@ -98,6 +116,8 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
givenJobs(Collections.emptyList(), Collections.emptyList());
givenDatafeeds(Collections.emptyList());
givenDataFrameAnalytics(Collections.emptyList());
givenProcessorStats(Collections.emptyList());
givenTrainedModelConfigCount(0);
}
public void testIsRunningOnMlPlatform() {
@ -175,14 +195,54 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
buildDatafeedStats(DatafeedState.STARTED),
buildDatafeedStats(DatafeedState.STOPPED)
));
givenDataFrameAnalytics(Arrays.asList(
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED),
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STOPPED),
buildDataFrameAnalyticsStats(DataFrameAnalyticsState.STARTED)
));
givenProcessorStats(Arrays.asList(
buildNodeStats(
Arrays.asList("pipeline1", "pipeline2", "pipeline3"),
Arrays.asList(
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)),
new IngestStats.ProcessorStat(
InferenceProcessor.TYPE,
InferenceProcessor.TYPE,
new IngestStats.Stats(100, 10, 0, 1))
),
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
)
)),
buildNodeStats(
Arrays.asList("pipeline1", "pipeline2", "pipeline3"),
Arrays.asList(
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)),
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
)
))
));
givenTrainedModelConfigCount(100);
MachineLearningFeatureSet featureSet = new MachineLearningFeatureSet(TestEnvironment.newEnvironment(settings.build()),
            clusterService, client, licenseState, jobManagerHolder);
PlainActionFuture<Usage> future = new PlainActionFuture<>();
featureSet.usage(future);
XPackFeatureSet.Usage mlUsage = future.get();
@ -258,6 +318,18 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
assertThat(source.getValue("jobs.opened.forecasts.total"), equalTo(11));
assertThat(source.getValue("jobs.opened.forecasts.forecasted_jobs"), equalTo(2));
assertThat(source.getValue("inference.trained_models._all.count"), equalTo(100));
assertThat(source.getValue("inference.ingest_processors._all.pipelines.count"), equalTo(2));
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.sum"), equalTo(130));
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.min"), equalTo(0));
assertThat(source.getValue("inference.ingest_processors._all.num_docs_processed.max"), equalTo(100));
assertThat(source.getValue("inference.ingest_processors._all.time_ms.sum"), equalTo(14));
assertThat(source.getValue("inference.ingest_processors._all.time_ms.min"), equalTo(0));
assertThat(source.getValue("inference.ingest_processors._all.time_ms.max"), equalTo(10));
assertThat(source.getValue("inference.ingest_processors._all.num_failures.sum"), equalTo(1));
assertThat(source.getValue("inference.ingest_processors._all.num_failures.min"), equalTo(0));
assertThat(source.getValue("inference.ingest_processors._all.num_failures.max"), equalTo(1));
}
}
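For reference, the expected values fall straight out of the fixture above: only pipeline1 and pipeline2 contain inference processors (hence pipelines.count of 2), and summing the inference ProcessorStat entries across both nodes gives num_docs_processed 10 + 100 + 5 + 0 + 10 + 5 = 130, time_ms 1 + 10 + 1 + 0 + 1 + 1 = 14, and a single failure contributed by the (100, 10, 0, 1) entry.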
@ -444,6 +516,34 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
}).when(client).execute(same(GetDataFrameAnalyticsStatsAction.INSTANCE), any(), any());
}
private void givenProcessorStats(List<NodeStats> stats) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<NodesStatsResponse> listener =
(ActionListener<NodesStatsResponse>) invocationOnMock.getArguments()[2];
listener.onResponse(new NodesStatsResponse(new ClusterName("_name"), stats, Collections.emptyList()));
return Void.TYPE;
}).when(client).execute(same(NodesStatsAction.INSTANCE), any(), any());
}
private void givenTrainedModelConfigCount(long count) {
when(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN))
.thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE));
ThreadPool pool = mock(ThreadPool.class);
when(pool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
when(client.threadPool()).thenReturn(pool);
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener =
(ActionListener<SearchResponse>) invocationOnMock.getArguments()[1];
SearchHits searchHits = new SearchHits(new SearchHit[0], new TotalHits(count, TotalHits.Relation.EQUAL_TO), (float)0.0);
SearchResponse searchResponse = mock(SearchResponse.class);
when(searchResponse.getHits()).thenReturn(searchHits);
listener.onResponse(searchResponse);
return Void.TYPE;
}).when(client).search(any(), any());
}
private static Detector buildMinDetector(String fieldName) {
Detector.Builder detectorBuilder = new Detector.Builder();
detectorBuilder.setFunction("min");
@ -490,6 +590,17 @@ public class MachineLearningFeatureSetTests extends ESTestCase {
return stats;
}
private static NodeStats buildNodeStats(List<String> pipelineNames, List<List<IngestStats.ProcessorStat>> processorStats) {
IngestStats ingestStats = new IngestStats(
            new IngestStats.Stats(0, 0, 0, 0),
Collections.emptyList(),
IntStream.range(0, pipelineNames.size()).boxed().collect(Collectors.toMap(pipelineNames::get, processorStats::get)));
return new NodeStats(mock(DiscoveryNode.class),
Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null,
null, null, null, ingestStats, null);
}
private static ForecastStats buildForecastStats(long numberOfForecasts) {
return new ForecastStatsTests().createForecastStats(numberOfForecasts, numberOfForecasts);
}

View File

@ -0,0 +1,289 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.action;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.junit.Before;
import java.io.IOException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class TransportGetTrainedModelsStatsActionTests extends ESTestCase {
private static class NotInferenceProcessor implements Processor {
@Override
public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
return ingestDocument;
}
@Override
public String getType() {
return "not_inference";
}
@Override
public String getTag() {
return null;
}
static class Factory implements Processor.Factory {
@Override
public Processor create(Map<String, Processor.Factory> processorFactories, String tag, Map<String, Object> config) {
return new NotInferenceProcessor();
}
}
}
private static final IngestPlugin SKINNY_INGEST_PLUGIN = new IngestPlugin() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
Map<String, Processor.Factory> factoryMap = new HashMap<>();
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
factoryMap.put(InferenceProcessor.TYPE,
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
factoryMap.put("not_inference", new NotInferenceProcessor.Factory());
return factoryMap;
}
};
private ClusterService clusterService;
private IngestService ingestService;
private Client client;
@Before
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client);
}
public void testInferenceIngestStatsByPipelineId() throws IOException {
List<NodeStats> nodeStatsList = Arrays.asList(
buildNodeStats(
new IngestStats.Stats(2, 2, 3, 4),
Arrays.asList(
new IngestStats.PipelineStat(
"pipeline1",
new IngestStats.Stats(0, 0, 3, 1)),
new IngestStats.PipelineStat(
"pipeline2",
new IngestStats.Stats(1, 1, 0, 1)),
new IngestStats.PipelineStat(
"pipeline3",
new IngestStats.Stats(2, 1, 1, 1))
),
Arrays.asList(
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0)),
new IngestStats.ProcessorStat(
InferenceProcessor.TYPE,
InferenceProcessor.TYPE,
new IngestStats.Stats(100, 10, 0, 1))
),
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
)
)),
buildNodeStats(
new IngestStats.Stats(15, 5, 3, 4),
Arrays.asList(
new IngestStats.PipelineStat(
"pipeline1",
new IngestStats.Stats(10, 1, 3, 1)),
new IngestStats.PipelineStat(
"pipeline2",
new IngestStats.Stats(1, 1, 0, 1)),
new IngestStats.PipelineStat(
"pipeline3",
new IngestStats.Stats(2, 1, 1, 1))
),
Arrays.asList(
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(0, 0, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(0, 0, 0, 0)),
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat(InferenceProcessor.TYPE, InferenceProcessor.TYPE, new IngestStats.Stats(5, 1, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
),
Arrays.asList(
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))
)
))
);
NodesStatsResponse response = new NodesStatsResponse(new ClusterName("_name"), nodeStatsList, Collections.emptyList());
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<String, Set<String>>(){{
put("trained_model_1", Collections.singleton("pipeline1"));
put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2")));
}};
Map<String, IngestStats> ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByPipelineId(response,
pipelineIdsByModelIds);
assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2"))));
IngestStats expectedStatsModel1 = new IngestStats(
new IngestStats.Stats(10, 1, 6, 2),
Collections.singletonList(new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2))),
Collections.singletonMap("pipeline1", Arrays.asList(
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))))
);
IngestStats expectedStatsModel2 = new IngestStats(
new IngestStats.Stats(12, 3, 6, 4),
Arrays.asList(
new IngestStats.PipelineStat("pipeline1", new IngestStats.Stats(10, 1, 6, 2)),
new IngestStats.PipelineStat("pipeline2", new IngestStats.Stats(2, 2, 0, 2))),
new HashMap<String, List<IngestStats.ProcessorStat>>() {{
put("pipeline2", Arrays.asList(
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(10, 2, 0, 0)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(20, 2, 0, 0))));
put("pipeline1", Arrays.asList(
new IngestStats.ProcessorStat("inference", "inference", new IngestStats.Stats(120, 12, 0, 1)),
new IngestStats.ProcessorStat("grok", "grok", new IngestStats.Stats(10, 1, 0, 0))));
}}
);
assertThat(ingestStatsMap, hasEntry("trained_model_1", expectedStatsModel1));
assertThat(ingestStatsMap, hasEntry("trained_model_2", expectedStatsModel2));
}
public void testPipelineIdsByModelIds() throws IOException {
String modelId1 = "trained_model_1";
String modelId2 = "trained_model_2";
String modelId3 = "trained_model_3";
Set<String> modelIds = new HashSet<>(Arrays.asList(modelId1, modelId2, modelId3));
ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3);
Map<String, Set<String>> pipelineIdsByModelIds =
TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds);
assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds));
assertThat(pipelineIdsByModelIds,
hasEntry(modelId1, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId1 + 0, "pipeline_with_model_" + modelId1 + 1))));
assertThat(pipelineIdsByModelIds,
hasEntry(modelId2, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId2 + 0, "pipeline_with_model_" + modelId2 + 1))));
assertThat(pipelineIdsByModelIds,
hasEntry(modelId3, new HashSet<>(Arrays.asList("pipeline_with_model_" + modelId3 + 0, "pipeline_with_model_" + modelId3 + 1))));
}
private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
for (String id : modelId) {
configurations.put("pipeline_with_model_" + id + 0, newConfigurationWithInferenceProcessor(id, 0));
configurations.put("pipeline_with_model_" + id + 1, newConfigurationWithInferenceProcessor(id, 1));
}
for (int i = 0; i < 3; i++) {
configurations.put("pipeline_without_model_" + i, newConfigurationWithOutInferenceProcessor(i));
}
IngestMetadata ingestMetadata = new IngestMetadata(configurations);
return ClusterState.builder(new ClusterName("_name"))
.metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
.build();
}
private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId, int num) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
Collections.singletonList(
Collections.singletonMap(InferenceProcessor.TYPE,
new HashMap<String, Object>() {{
put(InferenceProcessor.MODEL_ID, modelId);
put("inference_config", Collections.singletonMap("regression", Collections.emptyMap()));
put("field_mappings", Collections.emptyMap());
put("target_field", randomAlphaOfLength(10));
}}))))) {
return new PipelineConfiguration("pipeline_with_model_" + modelId + num,
BytesReference.bytes(xContentBuilder),
XContentType.JSON);
}
}
private static PipelineConfiguration newConfigurationWithOutInferenceProcessor(int i) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
Collections.singletonList(Collections.singletonMap("not_inference", Collections.emptyMap()))))) {
return new PipelineConfiguration("pipeline_without_model_" + i, BytesReference.bytes(xContentBuilder), XContentType.JSON);
}
}
    private static NodeStats buildNodeStats(IngestStats.Stats overallStats,
                                            List<IngestStats.PipelineStat> pipelineStats,
                                            List<List<IngestStats.ProcessorStat>> processorStats) {
        List<String> pipelineIds = pipelineStats.stream().map(IngestStats.PipelineStat::getPipelineId).collect(Collectors.toList());
        IngestStats ingestStats = new IngestStats(
            overallStats,
            pipelineStats,
            IntStream.range(0, pipelineIds.size()).boxed().collect(Collectors.toMap(pipelineIds::get, processorStats::get)));
return new NodeStats(mock(DiscoveryNode.class),
Instant.now().toEpochMilli(), null, null, null, null, null, null, null, null,
null, null, null, ingestStats, null);
}
}
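The merge that testInferenceIngestStatsByPipelineId exercises boils down to summing Stats fields for the same pipeline across nodes before grouping them under each referencing model id. A minimal sketch of that fold (an assumed helper written for illustration, not the production method; assumes the usual IngestStats.Stats getters):

    static IngestStats.Stats merge(IngestStats.Stats a, IngestStats.Stats b) {
        // Field order matches the (count, timeInMillis, current, failed) constructor used above.
        return new IngestStats.Stats(
            a.getIngestCount() + b.getIngestCount(),
            a.getIngestTimeInMillis() + b.getIngestTimeInMillis(),
            a.getIngestCurrent() + b.getIngestCurrent(),
            a.getIngestFailedCount() + b.getIngestFailedCount());
    }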

View File

@ -145,6 +145,8 @@ public class AnalyticsResultProcessorTests extends ESTestCase {
assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION));
assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build()));
assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames));
assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed()));
assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations()));
Map<String, Object> metadata = storedModel.getMetadata();
assertThat(metadata.size(), equalTo(1));
assertThat(metadata, hasKey("analytics_config"));

View File

@ -40,7 +40,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase<AnalyticsResu
progressPercent = randomIntBetween(0, 100);
}
if (randomBoolean()) {
            inferenceModel = TrainedModelDefinitionTests.createRandomBuilder("model");
}
return new AnalyticsResult(rowResults, progressPercent, inferenceModel);
}

View File

@ -0,0 +1,281 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.ingest;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.Version;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.ingest.Processor;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.plugins.IngestPlugin;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class InferenceProcessorFactoryTests extends ESTestCase {
private static final IngestPlugin SKINNY_PLUGIN = new IngestPlugin() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
XPackLicenseState licenseState = mock(XPackLicenseState.class);
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
return Collections.singletonMap(InferenceProcessor.TYPE,
new InferenceProcessor.Factory(parameters.client,
parameters.ingestService.getClusterService(),
Settings.EMPTY,
parameters.ingestService,
licenseState));
}
};
private Client client;
private XPackLicenseState licenseState;
private ClusterService clusterService;
private IngestService ingestService;
@Before
public void setUpVariables() {
ThreadPool tp = mock(ThreadPool.class);
client = mock(Client.class);
clusterService = mock(ClusterService.class);
ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY,
Collections.singleton(InferenceProcessor.MAX_INFERENCE_PROCESSORS));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
ingestService = new IngestService(clusterService, tp, null, null,
null, Collections.singletonList(SKINNY_PLUGIN), client);
licenseState = mock(XPackLicenseState.class);
when(licenseState.isMachineLearningAllowed()).thenReturn(true);
}
public void testNumInferenceProcessors() throws Exception {
MetaData metaData = null;
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
processorFactory.accept(buildClusterState(metaData));
assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
metaData = MetaData.builder().build();
processorFactory.accept(buildClusterState(metaData));
assertThat(processorFactory.numInferenceProcessors(), equalTo(0));
processorFactory.accept(buildClusterStateWithModelReferences("model1", "model2", "model3"));
assertThat(processorFactory.numInferenceProcessors(), equalTo(3));
}
public void testCreateProcessorWithTooManyExisting() throws Exception {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.builder().put(InferenceProcessor.MAX_INFERENCE_PROCESSORS.getKey(), 1).build(),
ingestService,
licenseState);
processorFactory.accept(buildClusterStateWithModelReferences("model1"));
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", Collections.emptyMap()));
assertThat(ex.getMessage(), equalTo("Max number of inference processors reached, total inference processors [1]. " +
"Adjust the setting [xpack.ml.max_inference_processors]: [1] if a greater number is desired."));
}
public void testCreateProcessorWithInvalidInferenceConfig() {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
Map<String, Object> config = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap()));
}};
ElasticsearchStatusException ex = expectThrows(ElasticsearchStatusException.class,
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config));
assertThat(ex.getMessage(),
equalTo("unrecognized inference configuration type [unknown_type]. Supported types [classification, regression]"));
Map<String, Object> config2 = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom"));
}};
ex = expectThrows(ElasticsearchStatusException.class,
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config2));
assertThat(ex.getMessage(),
equalTo("inference_config must be an object with one inference type mapped to an object."));
Map<String, Object> config3 = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap());
}};
ex = expectThrows(ElasticsearchStatusException.class,
() -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", config3));
assertThat(ex.getMessage(),
equalTo("inference_config must be an object with one inference type mapped to an object."));
}
public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
processorFactory.accept(builderClusterStateWithModelReferences(Version.V_7_5_0, "model1"));
Map<String, Object> regression = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
}};
try {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
fail("Should not have successfully created");
} catch (ElasticsearchException ex) {
assertThat(ex.getMessage(),
equalTo("Configuration [regression] requires minimum node version [7.6.0] (current minimum node version [7.5.0]"));
} catch (Exception ex) {
fail(ex.getMessage());
}
Map<String, Object> classification = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
}};
try {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification);
fail("Should not have successfully created");
} catch (ElasticsearchException ex) {
assertThat(ex.getMessage(),
equalTo("Configuration [classification] requires minimum node version [7.6.0] (current minimum node version [7.5.0]"));
} catch (Exception ex) {
fail(ex.getMessage());
}
}
public void testCreateProcessor() {
InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client,
clusterService,
Settings.EMPTY,
ingestService,
licenseState);
Map<String, Object> regression = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
}};
try {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", regression);
} catch (Exception ex) {
fail(ex.getMessage());
}
Map<String, Object> classification = new HashMap<String, Object>() {{
put(InferenceProcessor.FIELD_MAPPINGS, Collections.emptyMap());
put(InferenceProcessor.MODEL_ID, "my_model");
put(InferenceProcessor.TARGET_FIELD, "result");
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(ClassificationConfig.NAME,
Collections.singletonMap(ClassificationConfig.NUM_TOP_CLASSES.getPreferredName(), 1)));
}};
try {
processorFactory.create(Collections.emptyMap(), "my_inference_processor", classification);
} catch (Exception ex) {
fail(ex.getMessage());
}
}
private static ClusterState buildClusterState(MetaData metaData) {
return ClusterState.builder(new ClusterName("_name")).metaData(metaData).build();
}
private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException {
return builderClusterStateWithModelReferences(Version.CURRENT, modelId);
}
private static ClusterState builderClusterStateWithModelReferences(Version minNodeVersion, String... modelId) throws IOException {
Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
for (String id : modelId) {
configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id));
}
IngestMetadata ingestMetadata = new IngestMetadata(configurations);
return ClusterState.builder(new ClusterName("_name"))
.metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
.nodes(DiscoveryNodes.builder()
.add(new DiscoveryNode("min_node",
new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
minNodeVersion))
.add(new DiscoveryNode("current_node",
new TransportAddress(InetAddress.getLoopbackAddress(), 9302),
Version.CURRENT))
.localNodeId("_node_id")
.masterNodeId("_node_id"))
.build();
}
private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
Collections.singletonList(
Collections.singletonMap(InferenceProcessor.TYPE,
new HashMap<String, Object>() {{
put(InferenceProcessor.MODEL_ID, modelId);
put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME, Collections.emptyMap()));
put(InferenceProcessor.TARGET_FIELD, "new_field");
put(InferenceProcessor.FIELD_MAPPINGS, Collections.singletonMap("source", "dest"));
}}))))) {
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
}
}
}

View File

@ -0,0 +1,230 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.ingest;
import org.elasticsearch.client.Client;
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.junit.Before;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
public class InferenceProcessorTests extends ESTestCase {
private Client client;
@Before
public void setUpVariables() {
client = mock(Client.class);
}
public void testMutateDocumentWithClassification() {
String targetField = "classification_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"classification_model",
new ClassificationConfig(0),
Collections.emptyMap(),
"ml.my_processor",
true);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", null)));
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, String.class), equalTo("foo"));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
}
@SuppressWarnings("unchecked")
public void testMutateDocumentClassificationTopNClasses() {
String targetField = "classification_value_probabilities";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"classification_model",
new ClassificationConfig(2),
Collections.emptyMap(),
"ml.my_processor",
true);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
List<ClassificationInferenceResults.TopClassEntry> classes = new ArrayList<>(2);
classes.add(new ClassificationInferenceResults.TopClassEntry("foo", 0.6));
classes.add(new ClassificationInferenceResults.TopClassEntry("bar", 0.4));
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new ClassificationInferenceResults(1.0, "foo", classes)));
inferenceProcessor.mutateDocument(response, document);
assertThat((List<Map<?,?>>)document.getFieldValue(targetField, List.class),
contains(classes.stream().map(ClassificationInferenceResults.TopClassEntry::asValueMap).toArray(Map[]::new)));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "classification_model"))));
}
public void testMutateDocumentRegression() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
true);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(Collections.singletonMap("my_processor", Collections.singletonMap("model_id", "regression_model"))));
}
public void testMutateDocumentNoModelMetaData() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
false);
Map<String, Object> source = new HashMap<>();
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.hasField("ml"), is(false));
}
public void testMutateDocumentModelMetaDataExistingField() {
String targetField = "regression_value";
InferenceProcessor inferenceProcessor = new InferenceProcessor(client,
"my_processor",
targetField,
"regression_model",
new RegressionConfig(),
Collections.emptyMap(),
"ml.my_processor",
true);
// cannot use a singleton map because the processor mutates it when appending its metadata
Map<String, Object> ml = new HashMap<String, Object>(){{
put("regression_prediction", 0.55);
}};
Map<String, Object> source = new HashMap<String, Object>(){{
put("ml", ml);
}};
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
InferModelAction.Response response = new InferModelAction.Response(
Collections.singletonList(new RegressionInferenceResults(0.7)));
inferenceProcessor.mutateDocument(response, document);
assertThat(document.getFieldValue(targetField, Double.class), equalTo(0.7));
assertThat(document.getFieldValue("ml", Map.class),
equalTo(new HashMap<String, Object>(){{
put("my_processor", Collections.singletonMap("model_id", "regression_model"));
put("regression_prediction", 0.55);
}}));
}
public void testGenerateRequestWithEmptyMapping() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
InferenceProcessor processor = new InferenceProcessor(client,
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
Collections.emptyMap(),
"ml.my_processor",
false);
Map<String, Object> source = new HashMap<String, Object>(){{
put("value1", 1);
put("value2", 4);
put("categorical", "foo");
}};
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(source));
}
public void testGenerateWithMapping() {
String modelId = "model";
Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10);
Map<String, String> fieldMapping = new HashMap<String, String>(3) {{
put("value1", "new_value1");
put("value2", "new_value2");
put("categorical", "new_categorical");
}};
InferenceProcessor processor = new InferenceProcessor(client,
"my_processor",
"my_field",
modelId,
new ClassificationConfig(topNClasses),
fieldMapping,
"ml.my_processor",
false);
Map<String, Object> source = new HashMap<String, Object>(3){{
put("value1", 1);
put("categorical", "foo");
put("un_touched", "bar");
}};
Map<String, Object> ingestMetadata = new HashMap<>();
IngestDocument document = new IngestDocument(source, ingestMetadata);
Map<String, Object> expectedMap = new HashMap<String, Object>(2) {{
put("new_value1", 1);
put("new_categorical", "foo");
put("un_touched", "bar");
}};
assertThat(processor.buildRequest(document).getObjectsToInfer().get(0), equalTo(expectedMap));
}
}
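
The renaming contract that testGenerateWithMapping pins down is small enough to state directly: mapped fields are copied under their new names, mappings whose source field is absent (value2 above) are skipped, and unmapped fields such as un_touched pass through unchanged. The sketch below illustrates that contract; FieldMappingSketch and applyFieldMapping are illustrative names, not the processor's internals.

import java.util.HashMap;
import java.util.Map;

// Illustrative sketch of the field-mapping behavior asserted above, not production code.
class FieldMappingSketch {
    static Map<String, Object> applyFieldMapping(Map<String, Object> source, Map<String, String> fieldMapping) {
        Map<String, Object> mapped = new HashMap<>(source); // unmapped fields pass through unchanged
        fieldMapping.forEach((oldName, newName) -> {
            Object value = mapped.remove(oldName);
            if (value != null) {
                mapped.put(newName, value); // e.g. "value1" -> "new_value1"
            }
        });
        return mapped;
    }
}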

View File

@ -0,0 +1,212 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
public class LocalModelTests extends ESTestCase {
public void testClassificationInfer() throws Exception {
String modelId = "classification_model";
TrainedModelDefinition definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(false))
.build();
Model model = new LocalModel(modelId, definition);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("categorical", "dog");
}};
SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), is("0.0"));
ClassificationInferenceResults classificationResult =
(ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0"));
// Test with labels
definition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildClassification(true))
.build();
model = new LocalModel(modelId, definition);
result = getSingleValue(model, fields, new ClassificationConfig(0));
assertThat(result.value(), equalTo(0.0));
assertThat(result.valueAsString(), equalTo("not_to_be"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));
assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001));
assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2));
assertThat(classificationResult.getTopClasses(), hasSize(2));
classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1));
assertThat(classificationResult.getTopClasses(), hasSize(2));
}
public void testRegression() throws Exception {
TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap())))
.setTrainedModel(buildRegression())
.build();
Model model = new LocalModel("regression_model", trainedModelDefinition);
Map<String, Object> fields = new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("categorical", "dog");
}};
SingleValueInferenceResults results = getSingleValue(model, fields, new RegressionConfig());
assertThat(results.value(), equalTo(1.3));
PlainActionFuture<InferenceResults> failedFuture = new PlainActionFuture<>();
model.infer(fields, new ClassificationConfig(2), failedFuture);
ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get);
assertThat(ex.getCause().getMessage(),
equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]"));
}
private static SingleValueInferenceResults getSingleValue(Model model,
Map<String, Object> fields,
InferenceConfig config) throws Exception {
PlainActionFuture<InferenceResults> future = new PlainActionFuture<>();
model.infer(fields, config, future);
return (SingleValueInferenceResults)future.get();
}
private static Map<String, String> oneHotMap() {
Map<String, String> oneHotEncoding = new HashMap<>();
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
return oneHotEncoding;
}
public static TrainedModel buildClassification(boolean includeLabels) {
List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2)
.setThreshold(0.8)
.setSplitFeature(1)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(3)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.build();
return Ensemble.builder()
.setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null)
.setTargetType(TargetType.CLASSIFICATION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0}))
.build();
}
public static TrainedModel buildRegression() {
List<String> featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog");
Tree tree1 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(0)
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.3))
.addNode(TreeNode.builder(2)
.setThreshold(0.0)
.setSplitFeature(3)
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.1))
.addNode(TreeNode.builder(4).setLeafValue(0.2)).build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(2)
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
.setLeftChild(1)
.setRightChild(2)
.setSplitFeature(1)
.setThreshold(0.2))
.addNode(TreeNode.builder(1).setLeafValue(1.5))
.addNode(TreeNode.builder(2).setLeafValue(0.9))
.build();
return Ensemble.builder()
.setTargetType(TargetType.REGRESSION)
.setFeatureNames(featureNames)
.setTrainedModels(Arrays.asList(tree1, tree2, tree3))
.setOutputAggregator(new WeightedSum(new double[]{0.5, 0.5, 0.5}))
.build();
}
}
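
The regression expectation above can be reproduced by hand. Assuming a tree node routes to its left child when the feature value is less than or equal to its threshold (the convention consistent with every value these tests assert), the fields traverse the three trees as sketched below; this is a worked trace, not library code.

public class RegressionTraceSketch {
    public static void main(String[] args) {
        // foo=1.0, bar=0.5, and "dog" one-hot encodes to animal_dog=1.0, animal_cat=0.0
        double tree1 = 0.2; // foo=1.0 > 0.5 -> right; animal_dog=1.0 > 0.0 -> right -> leaf 0.2
        double tree2 = 1.5; // animal_cat=0.0 <= 1.0 -> left -> leaf 1.5
        double tree3 = 0.9; // bar=0.5 > 0.2 -> right -> leaf 0.9
        double prediction = 0.5 * tree1 + 0.5 * tree2 + 0.5 * tree3; // WeightedSum weights {0.5, 0.5, 0.5}
        System.out.println(prediction); // 1.3, matching the assertion above
    }
}

The top-class probability asserted in testClassificationInfer (0.5498339973124778) is likewise consistent with a softmax over the ensemble's summed vote weights for these fields, [1.2, 1.0] under the same traversal convention, since e^0.2 / (1 + e^0.2) ≈ 0.549834; that reading is inferred from the numbers, not taken from the aggregator's documented formula.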

View File

@ -0,0 +1,363 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.inference.loadingservice;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ScalingExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
import org.junit.After;
import org.junit.Before;
import org.mockito.Mockito;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class ModelLoadingServiceTests extends ESTestCase {
private TrainedModelProvider trainedModelProvider;
private ThreadPool threadPool;
private ClusterService clusterService;
private InferenceAuditor auditor;
@Before
public void setUpComponents() {
threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME,
1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool"));
trainedModelProvider = mock(TrainedModelProvider.class);
clusterService = mock(ClusterService.class);
auditor = mock(InferenceAuditor.class);
doAnswer(a -> null).when(auditor).error(any(String.class), any(String.class));
doAnswer(a -> null).when(auditor).info(any(String.class), any(String.class));
doAnswer(a -> null).when(auditor).warning(any(String.class), any(String.class));
doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class));
when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build());
}
@After
public void terminateThreadPool() {
terminate(threadPool);
}
public void testGetCachedModels() throws Exception {
String model1 = "test-load-model-1";
String model2 = "test-load-model-2";
String model3 = "test-load-model-3";
withTrainedModel(model1, 1L);
withTrainedModel(model2, 1L);
withTrainedModel(model3, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.EMPTY);
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
String[] modelIds = new String[]{model1, model2, model3};
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
// Test invalidate cache for model3
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
// model3 is no longer referenced by any pipeline, so each call loads it eagerly from the provider
verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any());
}
public void testMaxCachedLimitReached() throws Exception {
String model1 = "test-cached-limit-load-model-1";
String model2 = "test-cached-limit-load-model-2";
String model3 = "test-cached-limit-load-model-3";
withTrainedModel(model1, 10L);
withTrainedModel(model2, 5L);
withTrainedModel(model3, 15L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.builder().put(ModelLoadingService.INFERENCE_MODEL_CACHE_SIZE.getKey(), new ByteSizeValue(20L)).build());
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3));
// Should have been loaded from the cluster change event
// Verify that we have at least loaded all three so that evictions occur in the following loop
assertBusy(() -> {
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
});
String[] modelIds = new String[]{model1, model2, model3};
for(int i = 0; i < 10; i++) {
// Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load)
String model = modelIds[i%2];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any());
// Only loaded once, on the initial load triggered by the change event
verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any());
// Load model 3, should invalidate 1
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future3 = new PlainActionFuture<>();
modelLoadingService.getModel(model3, future3);
assertThat(future3.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any());
// Load model 1, should invalidate 2
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future1 = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future1);
assertThat(future1.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
// Load model 2, should invalidate 3
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future2 = new PlainActionFuture<>();
modelLoadingService.getModel(model2, future2);
assertThat(future2.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
// Test invalidate cache for model3
// Now both model 1 and 2 should fit in cache without issues
modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2));
for(int i = 0; i < 10; i++) {
String model = modelIds[i%3];
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any());
verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any());
verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any());
verify(trainedModelProvider, Mockito.atMost(5)).getTrainedModel(eq(model3), eq(true), any());
}
public void testWhenCacheEnabledButNotIngestNode() throws Exception {
String model1 = "test-uncached-not-ingest-model-1";
withTrainedModel(model1, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.EMPTY);
modelLoadingService.clusterChanged(ingestChangedEvent(false, model1));
for(int i = 0; i < 10; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model1, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(10)).getTrainedModel(eq(model1), eq(true), any());
}
public void testGetCachedMissingModel() throws Exception {
String model = "test-load-cached-missing-model";
withMissingModel(model);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.EMPTY);
modelLoadingService.clusterChanged(ingestChangedEvent(model));
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
try {
future.get();
fail("Should not have succeeded in loaded model");
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(true), any());
}
public void testGetMissingModel() {
String model = "test-load-missing-model";
withMissingModel(model);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.EMPTY);
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
try {
future.get();
fail("Should not have succeeded");
} catch (Exception ex) {
assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
}
public void testGetModelEagerly() throws Exception {
String model = "test-get-model-eagerly";
withTrainedModel(model, 1L);
ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider,
auditor,
threadPool,
clusterService,
Settings.EMPTY);
for(int i = 0; i < 3; i++) {
PlainActionFuture<Model> future = new PlainActionFuture<>();
modelLoadingService.getModel(model, future);
assertThat(future.get(), is(not(nullValue())));
}
verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(true), any());
}
@SuppressWarnings("unchecked")
private void withTrainedModel(String modelId, long size) {
TrainedModelDefinition definition = mock(TrainedModelDefinition.class);
when(definition.ramBytesUsed()).thenReturn(size);
TrainedModelConfig trainedModelConfig = mock(TrainedModelConfig.class);
when(trainedModelConfig.getDefinition()).thenReturn(definition);
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onResponse(trainedModelConfig);
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
}
private void withMissingModel(String modelId) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("rawtypes")
ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2];
listener.onFailure(new ResourceNotFoundException(
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)));
return null;
}).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(true), any());
}
private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException {
return ingestChangedEvent(true, modelId);
}
private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... modelId) throws IOException {
ClusterChangedEvent event = mock(ClusterChangedEvent.class);
when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE));
when(event.state()).thenReturn(buildClusterStateWithModelReferences(isIngestNode, modelId));
return event;
}
private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... modelId) throws IOException {
Map<String, PipelineConfiguration> configurations = new HashMap<>(modelId.length);
for (String id : modelId) {
configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id));
}
IngestMetadata ingestMetadata = new IngestMetadata(configurations);
return ClusterState.builder(new ClusterName("_name"))
.metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata))
.nodes(DiscoveryNodes.builder().add(
new DiscoveryNode("node_name",
"node_id",
new TransportAddress(InetAddress.getLoopbackAddress(), 9300),
Collections.emptyMap(),
isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(),
Version.CURRENT))
.localNodeId("node_id")
.build())
.build();
}
private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors",
Collections.singletonList(
Collections.singletonMap(InferenceProcessor.TYPE,
Collections.singletonMap(InferenceProcessor.MODEL_ID,
modelId)))))) {
return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON);
}
}
}
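
Taken together, these tests pin down a reference-aware caching contract: models referenced by an ingest pipeline are cached (the service pre-loads them when the cluster state changes), unreferenced models are fetched from the provider on every call, and a bounded cache evicts under memory pressure. The sketch below captures only the reference-tracking part, with a generic loader standing in for TrainedModelProvider; ReferenceAwareCache and its members are hypothetical, and unlike the real service it loads lazily rather than pre-loading.

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

// Hypothetical sketch of the reference-aware caching behavior verified above.
class ReferenceAwareCache<T> {
    private final Map<String, T> cache = new ConcurrentHashMap<>();
    private volatile Set<String> referencedModelIds = Set.of();
    private final Function<String, T> loader; // stands in for TrainedModelProvider

    ReferenceAwareCache(Function<String, T> loader) {
        this.loader = loader;
    }

    // Called on cluster-state changes: track the models referenced by ingest pipelines.
    void setReferencedModels(Set<String> modelIds) {
        referencedModelIds = Set.copyOf(modelIds);
        cache.keySet().retainAll(modelIds); // evict models that are no longer referenced
    }

    T getModel(String modelId) {
        if (referencedModelIds.contains(modelId)) {
            return cache.computeIfAbsent(modelId, loader); // loaded once, then served from cache
        }
        return loader.apply(modelId); // unreferenced models are loaded eagerly on every call
    }
}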

View File

@ -0,0 +1,196 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.junit.Before;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification;
import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.nullValue;
public class ModelInferenceActionIT extends MlSingleNodeTestCase {
private TrainedModelProvider trainedModelProvider;
@Before
public void createComponents() throws Exception {
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
waitForMlTemplates();
}
public void testInferModels() throws Exception {
String modelId1 = "test-load-models-regression";
String modelId2 = "test-load-models-classification";
Map<String, String> oneHotEncoding = new HashMap<>();
oneHotEncoding.put("cat", "animal_cat");
oneHotEncoding.put("dog", "animal_dog");
TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setTrainedModel(buildClassification(true))
.setModelId(modelId2))
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
.build();
TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1)
.setInput(new TrainedModelInput(Arrays.asList("field1", "field2")))
.setDefinition(new TrainedModelDefinition.Builder()
.setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding)))
.setTrainedModel(buildRegression())
.setModelId(modelId1))
.setVersion(Version.CURRENT)
.setEstimatedOperations(0)
.setEstimatedHeapMemory(0)
.setCreateTime(Instant.now())
.build();
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder);
assertThat(putConfigHolder.get(), is(true));
assertThat(exceptionHolder.get(), is(nullValue()));
List<Map<String, Object>> toInfer = new ArrayList<>();
toInfer.add(new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.5);
put("categorical", "dog");
}});
toInfer.add(new HashMap<String, Object>() {{
put("foo", 0.9);
put("bar", 1.5);
put("categorical", "cat");
}});
List<Map<String, Object>> toInfer2 = new ArrayList<>();
toInfer2.add(new HashMap<String, Object>() {{
put("foo", 0.0);
put("bar", 0.01);
put("categorical", "dog");
}});
toInfer2.add(new HashMap<String, Object>() {{
put("foo", 1.0);
put("bar", 0.0);
put("categorical", "cat");
}});
// Test regression
InferModelAction.Request request = new InferModelAction.Request(modelId1, toInfer, new RegressionConfig());
InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.3, 1.25));
request = new InferModelAction.Request(modelId1, toInfer2, new RegressionConfig());
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()),
contains(1.65, 1.55));
// Test classification
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(0));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
assertThat(response.getInferenceResults()
.stream()
.map(i -> ((SingleValueInferenceResults)i).valueAsString())
.collect(Collectors.toList()),
contains("not_to_be", "to_be"));
// Get top classes
request = new InferModelAction.Request(modelId2, toInfer, new ClassificationConfig(2));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
ClassificationInferenceResults classificationInferenceResults =
(ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1);
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be"));
// classes should always be ordered from most probable to least probable
assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(),
greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability()));
// Test that top classes restrict the number returned
request = new InferModelAction.Request(modelId2, toInfer2, new ClassificationConfig(1));
response = client().execute(InferModelAction.INSTANCE, request).actionGet();
classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0);
assertThat(classificationInferenceResults.getTopClasses(), hasSize(1));
assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be"));
}
public void testInferMissingModel() {
String model = "test-infer-missing-model";
InferModelAction.Request request = new InferModelAction.Request(model, Collections.emptyList(), new RegressionConfig());
try {
client().execute(InferModelAction.INSTANCE, request).actionGet();
} catch (ElasticsearchException ex) {
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model)));
}
}
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
return TrainedModelConfig.builder()
.setCreatedBy("ml_test")
.setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId))
.setDescription("trained model config for test")
.setModelId(modelId);
}
@Override
public NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents());
return new NamedXContentRegistry(namedXContent);
}
}
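
The class predictions asserted for toInfer follow from tallying the ensemble's votes under the same left-if-less-than-or-equal traversal convention used earlier: each tree in buildClassification(true) votes 0 or 1, WeightedMode sums the weights behind each class, and the heavier class wins. A worked sketch of that tally, not library code:

public class ClassificationVoteSketch {
    public static void main(String[] args) {
        // Doc 1: foo=1.0, bar=0.5, "dog" -> animal_dog=1.0, animal_cat=0.0
        // tree1 -> 0 (foo > 0.5 right; bar <= 0.8 left), tree2 -> 0, tree3 -> 1
        double class0 = 0.7 + 0.5; // WeightedMode weights of the trees voting class 0
        double class1 = 1.0;       // weight of tree3, which votes class 1
        System.out.println(class0 > class1 ? "not_to_be" : "to_be"); // not_to_be

        // Doc 2: foo=0.9, bar=1.5, "cat" -> animal_cat=1.0, animal_dog=0.0
        // tree1 -> 1 (bar > 0.8 right), tree2 -> 0, tree3 -> 1 (foo <= 1.0 left)
        class0 = 0.5;
        class1 = 0.7 + 1.0;
        System.out.println(class0 > class1 ? "not_to_be" : "to_be"); // to_be
    }
}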

View File

@ -144,6 +144,8 @@ public class TrainedModelProviderIT extends MlSingleNodeTestCase {
.setDescription("trained model config for test")
.setModelId(modelId)
.setVersion(Version.CURRENT)
.setEstimatedHeapMemory(0)
.setEstimatedOperations(0)
.setInput(TrainedModelInputTests.createRandomInput());
}

View File

@ -0,0 +1,24 @@
{
"ml.delete_trained_model":{
"documentation":{
"url":"TODO"
},
"stability":"experimental",
"url":{
"paths":[
{
"path":"/_ml/inference/{model_id}",
"methods":[
"DELETE"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained model to delete"
}
}
}
]
}
}
}

View File

@ -0,0 +1,54 @@
{
"ml.get_trained_models":{
"documentation":{
"url":"TODO"
},
"stability":"experimental",
"url":{
"paths":[
{
"path":"/_ml/inference/{model_id}",
"methods":[
"GET"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained models to fetch"
}
}
},
{
"path":"/_ml/inference",
"methods":[
"GET"
]
}
]
},
"params":{
"allow_no_match":{
"type":"boolean",
"required":false,
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true
},
"include_model_definition":{
"type":"boolean",
"required":false,
"description":"Should the full model definition be included in the results. These definitions can be large",
"default":false
},
"from":{
"type":"int",
"description":"skips a number of trained models",
"default":0
},
"size":{
"type":"int",
"description":"specifies a max number of trained models to get",
"default":100
}
}
}
}

View File

@ -0,0 +1,48 @@
{
"ml.get_trained_models_stats":{
"documentation":{
"url":"TODO"
},
"stability":"experimental",
"url":{
"paths":[
{
"path":"/_ml/inference/{model_id}/_stats",
"methods":[
"GET"
],
"parts":{
"model_id":{
"type":"string",
"description":"The ID of the trained models stats to fetch"
}
}
},
{
"path":"/_ml/inference/_stats",
"methods":[
"GET"
]
}
]
},
"params":{
"allow_no_match":{
"type":"boolean",
"required":false,
"description":"Whether to ignore if a wildcard expression matches no trained models. (This includes `_all` string or when no trained models have been specified)",
"default":true
},
"from":{
"type":"int",
"description":"skips a number of trained models",
"default":0
},
"size":{
"type":"int",
"description":"specifies a max number of trained models to get",
"default":100
}
}
}
}

View File

@ -0,0 +1,113 @@
---
"Test get-all given no trained models exist":
- do:
ml.get_trained_models:
model_id: "_all"
- match: { count: 0 }
- match: { trained_model_configs: [] }
- do:
ml.get_trained_models:
model_id: "*"
- match: { count: 0 }
- match: { trained_model_configs: [] }
---
"Test get given missing trained model":
- do:
catch: missing
ml.get_trained_models:
model_id: "missing-trained-model"
---
"Test get given expression without matches and allow_no_match is false":
- do:
catch: missing
ml.get_trained_models:
model_id: "missing-trained-model*"
allow_no_match: false
---
"Test get given expression without matches and allow_no_match is true":
- do:
ml.get_trained_models:
model_id: "missing-trained-model*"
allow_no_match: true
- match: { count: 0 }
- match: { trained_model_configs: [] }
---
"Test delete given unused trained model":
- do:
index:
id: trained_model_config-unused-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "unused-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local"
}
- do:
indices.refresh: {}
- do:
ml.delete_trained_model:
model_id: "unused-regression-model"
- match: { acknowledged: true }
---
"Test delete with missing model":
- do:
catch: missing
ml.delete_trained_model:
model_id: "missing-trained-model"
---
"Test delete given used trained model":
- do:
index:
id: trained_model_config-used-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "used-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local"
}
- do:
indices.refresh: {}
- do:
ingest.put_pipeline:
id: "regression-model-pipeline"
body: >
{
"processors": [
{
"inference" : {
"model_id" : "used-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
}
}
]
}
- match: { acknowledged: true }
- do:
catch: conflict
ml.delete_trained_model:
model_id: "used-regression-model"

View File

@ -0,0 +1,230 @@
setup:
- skip:
features: headers
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-unused-regression-model1-0
index: .ml-inference-000001
body: >
{
"model_id": "unused-regression-model1",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local",
"doc_type": "trained_model_config"
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-unused-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "unused-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local",
"doc_type": "trained_model_config"
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
index:
id: trained_model_config-used-regression-model-0
index: .ml-inference-000001
body: >
{
"model_id": "used-regression-model",
"created_by": "ml_tests",
"version": "8.0.0",
"description": "empty model for tests",
"create_time": 0,
"model_version": 0,
"model_type": "local",
"doc_type": "trained_model_config"
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
indices.refresh: {}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ingest.put_pipeline:
id: "regression-model-pipeline"
body: >
{
"processors": [
{
"inference" : {
"model_id" : "used-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
}
}
]
}
- do:
headers:
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
ingest.put_pipeline:
id: "regression-model-pipeline-1"
body: >
{
"processors": [
{
"inference" : {
"model_id" : "used-regression-model",
"inference_config": {"regression": {}},
"target_field": "regression_field",
"field_mappings": {}
}
}
]
}
---
"Test get stats given missing trained model":
- do:
catch: missing
ml.get_trained_models_stats:
model_id: "missing-trained-model"
---
"Test get stats given expression without matches and allow_no_match is false":
- do:
catch: missing
ml.get_trained_models_stats:
model_id: "missing-trained-model*"
allow_no_match: false
---
"Test get stats given expression without matches and allow_no_match is true":
- do:
ml.get_trained_models_stats:
model_id: "missing-trained-model*"
allow_no_match: true
- match: { count: 0 }
- match: { trained_model_stats: [] }
---
"Test get stats given trained models":
- do:
ml.get_trained_models_stats:
model_id: "unused-regression-model"
- match: { count: 1 }
- do:
ml.get_trained_models_stats:
model_id: "_all"
- match: { count: 3 }
- match: { trained_model_stats.0.model_id: unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- match: { trained_model_stats.2.pipeline_count: 2 }
- is_true: trained_model_stats.2.ingest
- do:
ml.get_trained_models_stats:
model_id: "*"
- match: { count: 3 }
- match: { trained_model_stats.0.model_id: unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- match: { trained_model_stats.2.pipeline_count: 2 }
- is_true: trained_model_stats.2.ingest
- do:
ml.get_trained_models_stats:
model_id: "unused-regression-model*"
- match: { count: 2 }
- match: { trained_model_stats.0.model_id: unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
- match: { trained_model_stats.1.pipeline_count: 0 }
- is_false: trained_model_stats.1.ingest
- do:
ml.get_trained_models_stats:
model_id: "unused-regression-model*"
size: 1
- match: { count: 2 }
- match: { trained_model_stats.0.model_id: unused-regression-model }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- do:
ml.get_trained_models_stats:
model_id: "unused-regression-model*"
from: 1
size: 1
- match: { count: 2 }
- match: { trained_model_stats.0.model_id: unused-regression-model1 }
- match: { trained_model_stats.0.pipeline_count: 0 }
- is_false: trained_model_stats.0.ingest
- do:
ml.get_trained_models_stats:
model_id: "used-regression-model"
- match: { count: 1 }
- match: { trained_model_stats.0.model_id: used-regression-model }
- match: { trained_model_stats.0.pipeline_count: 2 }
- match:
trained_model_stats.0.ingest.total:
count: 0
time_in_millis: 0
current: 0
failed: 0
- match:
trained_model_stats.0.ingest.pipelines.regression-model-pipeline:
count: 0
time_in_millis: 0
current: 0
failed: 0
processors:
- inference:
type: inference
stats:
count: 0
time_in_millis: 0
current: 0
failed: 0
- match:
trained_model_stats.0.ingest.pipelines.regression-model-pipeline-1:
count: 0
time_in_millis: 0
current: 0
failed: 0
processors:
- inference:
type: inference
stats:
count: 0
time_in_millis: 0
current: 0
failed: 0