mirror of
https://github.com/honeymoose/OpenSearch.git
synced 2025-02-08 22:14:59 +00:00
* [ML][Inference][HLRC] add GET _stats (#49562) * fixing for backport
This commit is contained in:
parent
a42003b95b
commit
b5d7c939f8
@ -61,6 +61,7 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
|
|||||||
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
|
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
|
||||||
import org.elasticsearch.client.ml.GetRecordsRequest;
|
import org.elasticsearch.client.ml.GetRecordsRequest;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
||||||
import org.elasticsearch.client.ml.MlInfoRequest;
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
||||||
import org.elasticsearch.client.ml.OpenJobRequest;
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
||||||
import org.elasticsearch.client.ml.PostCalendarEventRequest;
|
import org.elasticsearch.client.ml.PostCalendarEventRequest;
|
||||||
@ -749,6 +750,31 @@ final class MLRequestConverters {
|
|||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Request getTrainedModelsStats(GetTrainedModelsStatsRequest getTrainedModelsStatsRequest) {
|
||||||
|
String endpoint = new EndpointBuilder()
|
||||||
|
.addPathPartAsIs("_ml", "inference")
|
||||||
|
.addPathPart(Strings.collectionToCommaDelimitedString(getTrainedModelsStatsRequest.getIds()))
|
||||||
|
.addPathPart("_stats")
|
||||||
|
.build();
|
||||||
|
RequestConverters.Params params = new RequestConverters.Params();
|
||||||
|
if (getTrainedModelsStatsRequest.getPageParams() != null) {
|
||||||
|
PageParams pageParams = getTrainedModelsStatsRequest.getPageParams();
|
||||||
|
if (pageParams.getFrom() != null) {
|
||||||
|
params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString());
|
||||||
|
}
|
||||||
|
if (pageParams.getSize() != null) {
|
||||||
|
params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (getTrainedModelsStatsRequest.getAllowNoMatch() != null) {
|
||||||
|
params.putParam(GetTrainedModelsStatsRequest.ALLOW_NO_MATCH,
|
||||||
|
Boolean.toString(getTrainedModelsStatsRequest.getAllowNoMatch()));
|
||||||
|
}
|
||||||
|
Request request = new Request(HttpGet.METHOD_NAME, endpoint);
|
||||||
|
request.addParameters(params.asMap());
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
|
static Request deleteTrainedModel(DeleteTrainedModelRequest deleteRequest) {
|
||||||
String endpoint = new EndpointBuilder()
|
String endpoint = new EndpointBuilder()
|
||||||
.addPathPartAsIs("_ml", "inference")
|
.addPathPartAsIs("_ml", "inference")
|
||||||
|
@ -77,6 +77,8 @@ import org.elasticsearch.client.ml.GetRecordsRequest;
|
|||||||
import org.elasticsearch.client.ml.GetRecordsResponse;
|
import org.elasticsearch.client.ml.GetRecordsResponse;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
|
||||||
import org.elasticsearch.client.ml.MlInfoRequest;
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
||||||
import org.elasticsearch.client.ml.MlInfoResponse;
|
import org.elasticsearch.client.ml.MlInfoResponse;
|
||||||
import org.elasticsearch.client.ml.OpenJobRequest;
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
||||||
@ -2338,6 +2340,49 @@ public final class MachineLearningClient {
|
|||||||
Collections.emptySet());
|
Collections.emptySet());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets trained model stats
|
||||||
|
* <p>
|
||||||
|
* For additional info
|
||||||
|
* see <a href="TODO">
|
||||||
|
* GET Trained Model Stats documentation</a>
|
||||||
|
*
|
||||||
|
* @param request The {@link GetTrainedModelsStatsRequest}
|
||||||
|
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
|
||||||
|
* @return {@link GetTrainedModelsStatsResponse} response object
|
||||||
|
*/
|
||||||
|
public GetTrainedModelsStatsResponse getTrainedModelsStats(GetTrainedModelsStatsRequest request,
|
||||||
|
RequestOptions options) throws IOException {
|
||||||
|
return restHighLevelClient.performRequestAndParseEntity(request,
|
||||||
|
MLRequestConverters::getTrainedModelsStats,
|
||||||
|
options,
|
||||||
|
GetTrainedModelsStatsResponse::fromXContent,
|
||||||
|
Collections.emptySet());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets trained model stats asynchronously and notifies listener upon completion
|
||||||
|
* <p>
|
||||||
|
* For additional info
|
||||||
|
* see <a href="TODO">
|
||||||
|
* GET Trained Model Stats documentation</a>
|
||||||
|
*
|
||||||
|
* @param request The {@link GetTrainedModelsStatsRequest}
|
||||||
|
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
|
||||||
|
* @param listener Listener to be notified upon request completion
|
||||||
|
* @return cancellable that may be used to cancel the request
|
||||||
|
*/
|
||||||
|
public Cancellable getTrainedModelsStatsAsync(GetTrainedModelsStatsRequest request,
|
||||||
|
RequestOptions options,
|
||||||
|
ActionListener<GetTrainedModelsStatsResponse> listener) {
|
||||||
|
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
|
||||||
|
MLRequestConverters::getTrainedModelsStats,
|
||||||
|
options,
|
||||||
|
GetTrainedModelsStatsResponse::fromXContent,
|
||||||
|
listener,
|
||||||
|
Collections.emptySet());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deletes the given Trained Model
|
* Deletes the given Trained Model
|
||||||
* <p>
|
* <p>
|
||||||
|
@ -0,0 +1,103 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.client.ml;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.Validatable;
|
||||||
|
import org.elasticsearch.client.ValidationException;
|
||||||
|
import org.elasticsearch.client.core.PageParams;
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
public class GetTrainedModelsStatsRequest implements Validatable {
|
||||||
|
|
||||||
|
public static final String ALLOW_NO_MATCH = "allow_no_match";
|
||||||
|
|
||||||
|
private final List<String> ids;
|
||||||
|
private Boolean allowNoMatch;
|
||||||
|
private PageParams pageParams;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper method to create a request that will get ALL TrainedModelStats
|
||||||
|
* @return new {@link GetTrainedModelsStatsRequest} object for the id "_all"
|
||||||
|
*/
|
||||||
|
public static GetTrainedModelsStatsRequest getAllTrainedModelStatsRequest() {
|
||||||
|
return new GetTrainedModelsStatsRequest("_all");
|
||||||
|
}
|
||||||
|
|
||||||
|
public GetTrainedModelsStatsRequest(String... ids) {
|
||||||
|
this.ids = Arrays.asList(ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getIds() {
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Boolean getAllowNoMatch() {
|
||||||
|
return allowNoMatch;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Whether to ignore if a wildcard expression matches no trained models.
|
||||||
|
*
|
||||||
|
* @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all})
|
||||||
|
* does not match any trained models
|
||||||
|
*/
|
||||||
|
public GetTrainedModelsStatsRequest setAllowNoMatch(boolean allowNoMatch) {
|
||||||
|
this.allowNoMatch = allowNoMatch;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public PageParams getPageParams() {
|
||||||
|
return pageParams;
|
||||||
|
}
|
||||||
|
|
||||||
|
public GetTrainedModelsStatsRequest setPageParams(@Nullable PageParams pageParams) {
|
||||||
|
this.pageParams = pageParams;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Optional<ValidationException> validate() {
|
||||||
|
if (ids == null || ids.isEmpty()) {
|
||||||
|
return Optional.of(ValidationException.withError("trained model id must not be null"));
|
||||||
|
}
|
||||||
|
return Optional.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
|
||||||
|
GetTrainedModelsStatsRequest other = (GetTrainedModelsStatsRequest) o;
|
||||||
|
return Objects.equals(ids, other.ids)
|
||||||
|
&& Objects.equals(allowNoMatch, other.allowNoMatch)
|
||||||
|
&& Objects.equals(pageParams, other.pageParams);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(ids, allowNoMatch, pageParams);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,86 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.client.ml;
|
||||||
|
|
||||||
|
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
|
||||||
|
public class GetTrainedModelsStatsResponse {
|
||||||
|
|
||||||
|
public static final ParseField TRAINED_MODEL_STATS = new ParseField("trained_model_stats");
|
||||||
|
public static final ParseField COUNT = new ParseField("count");
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
static final ConstructingObjectParser<GetTrainedModelsStatsResponse, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
"get_trained_model_stats",
|
||||||
|
true,
|
||||||
|
args -> new GetTrainedModelsStatsResponse((List<TrainedModelStats>) args[0], (Long) args[1]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareObjectArray(constructorArg(), (p, c) -> TrainedModelStats.fromXContent(p), TRAINED_MODEL_STATS);
|
||||||
|
PARSER.declareLong(constructorArg(), COUNT);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static GetTrainedModelsStatsResponse fromXContent(final XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
private final List<TrainedModelStats> trainedModelStats;
|
||||||
|
private final Long count;
|
||||||
|
|
||||||
|
|
||||||
|
public GetTrainedModelsStatsResponse(List<TrainedModelStats> trainedModelStats, Long count) {
|
||||||
|
this.trainedModelStats = trainedModelStats;
|
||||||
|
this.count = count;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<TrainedModelStats> getTrainedModelStats() {
|
||||||
|
return trainedModelStats;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return The total count of the trained models that matched the ID pattern.
|
||||||
|
*/
|
||||||
|
public Long getCount() {
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object o) {
|
||||||
|
if (this == o) return true;
|
||||||
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
|
|
||||||
|
GetTrainedModelsStatsResponse other = (GetTrainedModelsStatsResponse) o;
|
||||||
|
return Objects.equals(this.trainedModelStats, other.trainedModelStats) && Objects.equals(this.count, other.count);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(trainedModelStats, count);
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,123 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.Nullable;
|
||||||
|
import org.elasticsearch.common.ParseField;
|
||||||
|
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.ingest.IngestStats;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||||
|
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
|
||||||
|
|
||||||
|
public class TrainedModelStats implements ToXContentObject {
|
||||||
|
|
||||||
|
public static final ParseField MODEL_ID = new ParseField("model_id");
|
||||||
|
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
|
||||||
|
public static final ParseField INGEST_STATS = new ParseField("ingest");
|
||||||
|
|
||||||
|
private final String modelId;
|
||||||
|
private final Map<String, Object> ingestStats;
|
||||||
|
private final int pipelineCount;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
static final ConstructingObjectParser<TrainedModelStats, Void> PARSER =
|
||||||
|
new ConstructingObjectParser<>(
|
||||||
|
"trained_model_stats",
|
||||||
|
true,
|
||||||
|
args -> new TrainedModelStats((String) args[0], (Map<String, Object>) args[1], (Integer) args[2]));
|
||||||
|
|
||||||
|
static {
|
||||||
|
PARSER.declareString(constructorArg(), MODEL_ID);
|
||||||
|
PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), INGEST_STATS);
|
||||||
|
PARSER.declareInt(constructorArg(), PIPELINE_COUNT);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelStats fromXContent(XContentParser parser) {
|
||||||
|
return PARSER.apply(parser, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainedModelStats(String modelId, Map<String, Object> ingestStats, int pipelineCount) {
|
||||||
|
this.modelId = modelId;
|
||||||
|
this.ingestStats = ingestStats;
|
||||||
|
this.pipelineCount = pipelineCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The model id for which the stats apply
|
||||||
|
*/
|
||||||
|
public String getModelId() {
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ingest level statistics. See {@link IngestStats#toXContent(XContentBuilder, Params)} for fields and format
|
||||||
|
* If there are no ingest pipelines referencing the model, then the ingest statistics could be null.
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
public Map<String, Object> getIngestStats() {
|
||||||
|
return ingestStats;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The total number of pipelines that reference the trained model
|
||||||
|
*/
|
||||||
|
public int getPipelineCount() {
|
||||||
|
return pipelineCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
|
builder.startObject();
|
||||||
|
builder.field(MODEL_ID.getPreferredName(), modelId);
|
||||||
|
builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
|
||||||
|
if (ingestStats != null) {
|
||||||
|
builder.field(INGEST_STATS.getPreferredName(), ingestStats);
|
||||||
|
}
|
||||||
|
builder.endObject();
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(modelId, ingestStats, pipelineCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (obj == null) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (getClass() != obj.getClass()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
TrainedModelStats other = (TrainedModelStats) obj;
|
||||||
|
return Objects.equals(this.modelId, other.modelId)
|
||||||
|
&& Objects.equals(this.ingestStats, other.ingestStats)
|
||||||
|
&& Objects.equals(this.pipelineCount, other.pipelineCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -59,6 +59,7 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
|
|||||||
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
|
import org.elasticsearch.client.ml.GetOverallBucketsRequest;
|
||||||
import org.elasticsearch.client.ml.GetRecordsRequest;
|
import org.elasticsearch.client.ml.GetRecordsRequest;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
||||||
import org.elasticsearch.client.ml.MlInfoRequest;
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
||||||
import org.elasticsearch.client.ml.OpenJobRequest;
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
||||||
import org.elasticsearch.client.ml.PostCalendarEventRequest;
|
import org.elasticsearch.client.ml.PostCalendarEventRequest;
|
||||||
@ -825,7 +826,6 @@ public class MLRequestConvertersTests extends ESTestCase {
|
|||||||
Request request = MLRequestConverters.getTrainedModels(getRequest);
|
Request request = MLRequestConverters.getTrainedModels(getRequest);
|
||||||
assertEquals(HttpGet.METHOD_NAME, request.getMethod());
|
assertEquals(HttpGet.METHOD_NAME, request.getMethod());
|
||||||
assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3, request.getEndpoint());
|
assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3, request.getEndpoint());
|
||||||
assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false")));
|
|
||||||
assertThat(request.getParameters(),
|
assertThat(request.getParameters(),
|
||||||
allOf(
|
allOf(
|
||||||
hasEntry("from", "100"),
|
hasEntry("from", "100"),
|
||||||
@ -837,6 +837,26 @@ public class MLRequestConvertersTests extends ESTestCase {
|
|||||||
assertNull(request.getEntity());
|
assertNull(request.getEntity());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGetTrainedModelsStats() {
|
||||||
|
String modelId1 = randomAlphaOfLength(10);
|
||||||
|
String modelId2 = randomAlphaOfLength(10);
|
||||||
|
String modelId3 = randomAlphaOfLength(10);
|
||||||
|
GetTrainedModelsStatsRequest getRequest = new GetTrainedModelsStatsRequest(modelId1, modelId2, modelId3)
|
||||||
|
.setAllowNoMatch(false)
|
||||||
|
.setPageParams(new PageParams(100, 300));
|
||||||
|
|
||||||
|
Request request = MLRequestConverters.getTrainedModelsStats(getRequest);
|
||||||
|
assertEquals(HttpGet.METHOD_NAME, request.getMethod());
|
||||||
|
assertEquals("/_ml/inference/" + modelId1 + "," + modelId2 + "," + modelId3 + "/_stats", request.getEndpoint());
|
||||||
|
assertThat(request.getParameters(),
|
||||||
|
allOf(
|
||||||
|
hasEntry("from", "100"),
|
||||||
|
hasEntry("size", "300"),
|
||||||
|
hasEntry("allow_no_match", "false")
|
||||||
|
));
|
||||||
|
assertNull(request.getEntity());
|
||||||
|
}
|
||||||
|
|
||||||
public void testDeleteTrainedModel() {
|
public void testDeleteTrainedModel() {
|
||||||
DeleteTrainedModelRequest deleteRequest = new DeleteTrainedModelRequest(randomAlphaOfLength(10));
|
DeleteTrainedModelRequest deleteRequest = new DeleteTrainedModelRequest(randomAlphaOfLength(10));
|
||||||
Request request = MLRequestConverters.deleteTrainedModel(deleteRequest);
|
Request request = MLRequestConverters.deleteTrainedModel(deleteRequest);
|
||||||
|
@ -24,6 +24,7 @@ import org.elasticsearch.action.bulk.BulkRequest;
|
|||||||
import org.elasticsearch.action.get.GetRequest;
|
import org.elasticsearch.action.get.GetRequest;
|
||||||
import org.elasticsearch.action.get.GetResponse;
|
import org.elasticsearch.action.get.GetResponse;
|
||||||
import org.elasticsearch.action.index.IndexRequest;
|
import org.elasticsearch.action.index.IndexRequest;
|
||||||
|
import org.elasticsearch.action.ingest.PutPipelineRequest;
|
||||||
import org.elasticsearch.action.support.WriteRequest;
|
import org.elasticsearch.action.support.WriteRequest;
|
||||||
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
import org.elasticsearch.action.support.master.AcknowledgedResponse;
|
||||||
import org.elasticsearch.action.update.UpdateRequest;
|
import org.elasticsearch.action.update.UpdateRequest;
|
||||||
@ -77,6 +78,8 @@ import org.elasticsearch.client.ml.GetModelSnapshotsRequest;
|
|||||||
import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
|
import org.elasticsearch.client.ml.GetModelSnapshotsResponse;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
|
||||||
import org.elasticsearch.client.ml.MlInfoRequest;
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
||||||
import org.elasticsearch.client.ml.MlInfoResponse;
|
import org.elasticsearch.client.ml.MlInfoResponse;
|
||||||
import org.elasticsearch.client.ml.OpenJobRequest;
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
||||||
@ -148,6 +151,8 @@ import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
|||||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||||
|
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||||
import org.elasticsearch.client.ml.job.config.DataDescription;
|
import org.elasticsearch.client.ml.job.config.DataDescription;
|
||||||
import org.elasticsearch.client.ml.job.config.Detector;
|
import org.elasticsearch.client.ml.job.config.Detector;
|
||||||
@ -157,6 +162,7 @@ import org.elasticsearch.client.ml.job.config.JobUpdate;
|
|||||||
import org.elasticsearch.client.ml.job.config.MlFilter;
|
import org.elasticsearch.client.ml.job.config.MlFilter;
|
||||||
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
|
import org.elasticsearch.client.ml.job.process.ModelSnapshot;
|
||||||
import org.elasticsearch.client.ml.job.stats.JobStats;
|
import org.elasticsearch.client.ml.job.stats.JobStats;
|
||||||
|
import org.elasticsearch.common.bytes.BytesArray;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
||||||
import org.elasticsearch.common.unit.ByteSizeUnit;
|
import org.elasticsearch.common.unit.ByteSizeUnit;
|
||||||
@ -2123,6 +2129,67 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGetTrainedModelsStats() throws Exception {
|
||||||
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
|
String modelIdPrefix = "get-trained-model-stats-";
|
||||||
|
int numberOfModels = 5;
|
||||||
|
for (int i = 0; i < numberOfModels; ++i) {
|
||||||
|
String modelId = modelIdPrefix + i;
|
||||||
|
putTrainedModel(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
String regressionPipeline = "{" +
|
||||||
|
" \"processors\": [\n" +
|
||||||
|
" {\n" +
|
||||||
|
" \"inference\": {\n" +
|
||||||
|
" \"target_field\": \"regression_value\",\n" +
|
||||||
|
" \"model_id\": \"" + modelIdPrefix + 0 + "\",\n" +
|
||||||
|
" \"inference_config\": {\"regression\": {}},\n" +
|
||||||
|
" \"field_mappings\": {\n" +
|
||||||
|
" \"col1\": \"col1\",\n" +
|
||||||
|
" \"col2\": \"col2\",\n" +
|
||||||
|
" \"col3\": \"col3\",\n" +
|
||||||
|
" \"col4\": \"col4\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }]}\n";
|
||||||
|
|
||||||
|
highLevelClient().ingest().putPipeline(
|
||||||
|
new PutPipelineRequest("regression-stats-pipeline",
|
||||||
|
new BytesArray(regressionPipeline.getBytes(StandardCharsets.UTF_8)),
|
||||||
|
XContentType.JSON),
|
||||||
|
RequestOptions.DEFAULT);
|
||||||
|
{
|
||||||
|
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
||||||
|
GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
|
||||||
|
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
||||||
|
new GetTrainedModelsStatsRequest(modelIdPrefix + 4, modelIdPrefix + 2, modelIdPrefix + 3),
|
||||||
|
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(3));
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(3L));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
||||||
|
new GetTrainedModelsStatsRequest(modelIdPrefix + "*").setPageParams(new PageParams(1, 2)),
|
||||||
|
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(2));
|
||||||
|
assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
|
||||||
|
assertThat(
|
||||||
|
getTrainedModelsStatsResponse.getTrainedModelStats()
|
||||||
|
.stream()
|
||||||
|
.map(TrainedModelStats::getModelId)
|
||||||
|
.collect(Collectors.toList()),
|
||||||
|
containsInAnyOrder(modelIdPrefix + 1, modelIdPrefix + 2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testDeleteTrainedModel() throws Exception {
|
public void testDeleteTrainedModel() throws Exception {
|
||||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
String modelId = "delete-trained-model-test";
|
String modelId = "delete-trained-model-test";
|
||||||
@ -2328,7 +2395,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void putTrainedModel(String modelId) throws IOException {
|
private void putTrainedModel(String modelId) throws IOException {
|
||||||
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
|
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION).build();
|
||||||
highLevelClient().index(
|
highLevelClient().index(
|
||||||
new IndexRequest(".ml-inference-000001")
|
new IndexRequest(".ml-inference-000001")
|
||||||
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
|
||||||
|
@ -91,6 +91,8 @@ import org.elasticsearch.client.ml.GetRecordsRequest;
|
|||||||
import org.elasticsearch.client.ml.GetRecordsResponse;
|
import org.elasticsearch.client.ml.GetRecordsResponse;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
import org.elasticsearch.client.ml.GetTrainedModelsRequest;
|
||||||
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
import org.elasticsearch.client.ml.GetTrainedModelsResponse;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsRequest;
|
||||||
|
import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse;
|
||||||
import org.elasticsearch.client.ml.MlInfoRequest;
|
import org.elasticsearch.client.ml.MlInfoRequest;
|
||||||
import org.elasticsearch.client.ml.MlInfoResponse;
|
import org.elasticsearch.client.ml.MlInfoResponse;
|
||||||
import org.elasticsearch.client.ml.OpenJobRequest;
|
import org.elasticsearch.client.ml.OpenJobRequest;
|
||||||
@ -163,6 +165,7 @@ import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
|
|||||||
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.client.ml.inference.TrainedModelConfig;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
import org.elasticsearch.client.ml.inference.TrainedModelDefinitionTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.TrainedModelStats;
|
||||||
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
|
||||||
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
|
import org.elasticsearch.client.ml.job.config.AnalysisLimits;
|
||||||
import org.elasticsearch.client.ml.job.config.DataDescription;
|
import org.elasticsearch.client.ml.job.config.DataDescription;
|
||||||
@ -3593,6 +3596,58 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testGetTrainedModelsStats() throws Exception {
|
||||||
|
putTrainedModel("my-trained-model");
|
||||||
|
RestHighLevelClient client = highLevelClient();
|
||||||
|
{
|
||||||
|
// tag::get-trained-models-stats-request
|
||||||
|
GetTrainedModelsStatsRequest request =
|
||||||
|
new GetTrainedModelsStatsRequest("my-trained-model") // <1>
|
||||||
|
.setPageParams(new PageParams(0, 1)) // <2>
|
||||||
|
.setAllowNoMatch(true); // <3>
|
||||||
|
// end::get-trained-models-stats-request
|
||||||
|
|
||||||
|
// tag::get-trained-models-stats-execute
|
||||||
|
GetTrainedModelsStatsResponse response =
|
||||||
|
client.machineLearning().getTrainedModelsStats(request, RequestOptions.DEFAULT);
|
||||||
|
// end::get-trained-models-stats-execute
|
||||||
|
|
||||||
|
// tag::get-trained-models-stats-response
|
||||||
|
List<TrainedModelStats> models = response.getTrainedModelStats();
|
||||||
|
// end::get-trained-models-stats-response
|
||||||
|
|
||||||
|
assertThat(models, hasSize(1));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GetTrainedModelsStatsRequest request = new GetTrainedModelsStatsRequest("my-trained-model");
|
||||||
|
|
||||||
|
// tag::get-trained-models-stats-execute-listener
|
||||||
|
ActionListener<GetTrainedModelsStatsResponse> listener = new ActionListener<GetTrainedModelsStatsResponse>() {
|
||||||
|
@Override
|
||||||
|
public void onResponse(GetTrainedModelsStatsResponse response) {
|
||||||
|
// <1>
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onFailure(Exception e) {
|
||||||
|
// <2>
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// end::get-trained-models-stats-execute-listener
|
||||||
|
|
||||||
|
// Replace the empty listener by a blocking listener in test
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
listener = new LatchedActionListener<>(listener, latch);
|
||||||
|
|
||||||
|
// tag::get-trained-models-stats-execute-async
|
||||||
|
client.machineLearning()
|
||||||
|
.getTrainedModelsStatsAsync(request, RequestOptions.DEFAULT, listener); // <1>
|
||||||
|
// end::get-trained-models-stats-execute-async
|
||||||
|
|
||||||
|
assertTrue(latch.await(30L, TimeUnit.SECONDS));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testDeleteTrainedModel() throws Exception {
|
public void testDeleteTrainedModel() throws Exception {
|
||||||
RestHighLevelClient client = highLevelClient();
|
RestHighLevelClient client = highLevelClient();
|
||||||
{
|
{
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.elasticsearch.client.ml;
|
||||||
|
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
|
||||||
|
public class GetTrainedModelsStatsRequestTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testValidate_Ok() {
|
||||||
|
assertEquals(Optional.empty(), new GetTrainedModelsStatsRequest("valid-id").validate());
|
||||||
|
assertEquals(Optional.empty(), new GetTrainedModelsStatsRequest("").validate());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testValidate_Failure() {
|
||||||
|
assertThat(new GetTrainedModelsStatsRequest(new String[0]).validate().get().getMessage(),
|
||||||
|
containsString("trained model id must not be null"));
|
||||||
|
}
|
||||||
|
}
|
@ -21,6 +21,7 @@ package org.elasticsearch.client.ml.inference;
|
|||||||
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncodingTests;
|
||||||
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncodingTests;
|
||||||
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncodingTests;
|
||||||
|
import org.elasticsearch.client.ml.inference.trainedmodel.TargetType;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.EnsembleTests;
|
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.EnsembleTests;
|
||||||
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
import org.elasticsearch.client.ml.inference.trainedmodel.tree.TreeTests;
|
||||||
import org.elasticsearch.common.settings.Settings;
|
import org.elasticsearch.common.settings.Settings;
|
||||||
@ -56,6 +57,10 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
public static TrainedModelDefinition.Builder createRandomBuilder() {
|
||||||
|
return createRandomBuilder(randomFrom(TargetType.values()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targetType) {
|
||||||
int numberOfProcessors = randomIntBetween(1, 10);
|
int numberOfProcessors = randomIntBetween(1, 10);
|
||||||
return new TrainedModelDefinition.Builder()
|
return new TrainedModelDefinition.Builder()
|
||||||
.setPreProcessors(
|
.setPreProcessors(
|
||||||
@ -65,7 +70,8 @@ public class TrainedModelDefinitionTests extends AbstractXContentTestCase<Traine
|
|||||||
TargetMeanEncodingTests.createRandom()))
|
TargetMeanEncodingTests.createRandom()))
|
||||||
.limit(numberOfProcessors)
|
.limit(numberOfProcessors)
|
||||||
.collect(Collectors.toList()))
|
.collect(Collectors.toList()))
|
||||||
.setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom()));
|
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Collections.emptyList(), 6, targetType),
|
||||||
|
EnsembleTests.createRandom(targetType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -0,0 +1,96 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to Elasticsearch under one or more contributor
|
||||||
|
* license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright
|
||||||
|
* ownership. Elasticsearch licenses this file to you under
|
||||||
|
* the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
* not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing,
|
||||||
|
* software distributed under the License is distributed on an
|
||||||
|
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
* KIND, either express or implied. See the License for the
|
||||||
|
* specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.client.ml.inference;
|
||||||
|
|
||||||
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
|
import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
|
import org.elasticsearch.ingest.IngestStats;
|
||||||
|
import org.elasticsearch.test.AbstractXContentTestCase;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.function.Function;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
|
||||||
|
public class TrainedModelStatsTests extends AbstractXContentTestCase<TrainedModelStats> {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelStats doParseInstance(XContentParser parser) throws IOException {
|
||||||
|
return TrainedModelStats.fromXContent(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected boolean supportsUnknownFields() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Predicate<String> getRandomFieldsExcludeFilter() {
|
||||||
|
return field -> !field.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected TrainedModelStats createTestInstance() {
|
||||||
|
return new TrainedModelStats(
|
||||||
|
randomAlphaOfLength(10),
|
||||||
|
randomBoolean() ? null : randomIngestStats(),
|
||||||
|
randomInt());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Object> randomIngestStats() {
|
||||||
|
try {
|
||||||
|
List<String> pipelineIds = Stream.generate(()-> randomAlphaOfLength(10))
|
||||||
|
.limit(randomIntBetween(0, 10))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
IngestStats stats = new IngestStats(
|
||||||
|
new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()),
|
||||||
|
pipelineIds.stream().map(id -> new IngestStats.PipelineStat(id, randomStats())).collect(Collectors.toList()),
|
||||||
|
pipelineIds.stream().collect(Collectors.toMap(Function.identity(), (v) -> randomProcessorStats())));
|
||||||
|
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
|
||||||
|
builder.startObject();
|
||||||
|
stats.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||||
|
builder.endObject();
|
||||||
|
return XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
|
||||||
|
}
|
||||||
|
} catch (IOException ex) {
|
||||||
|
fail(ex.getMessage());
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private IngestStats.Stats randomStats(){
|
||||||
|
return new IngestStats.Stats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong());
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<IngestStats.ProcessorStat> randomProcessorStats() {
|
||||||
|
return Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
|
.limit(randomIntBetween(0, 10))
|
||||||
|
.map(name -> new IngestStats.ProcessorStat(name, "inference", randomStats()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -57,12 +57,16 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static Ensemble createRandom() {
|
public static Ensemble createRandom() {
|
||||||
|
return createRandom(randomFrom(TargetType.values()));
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Ensemble createRandom(TargetType targetType) {
|
||||||
int numberOfFeatures = randomIntBetween(1, 10);
|
int numberOfFeatures = randomIntBetween(1, 10);
|
||||||
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
|
List<String> featureNames = Stream.generate(() -> randomAlphaOfLength(10))
|
||||||
.limit(numberOfFeatures)
|
.limit(numberOfFeatures)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
int numberOfModels = randomIntBetween(1, 10);
|
int numberOfModels = randomIntBetween(1, 10);
|
||||||
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6))
|
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
|
||||||
.limit(numberOfFeatures)
|
.limit(numberOfFeatures)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
OutputAggregator outputAggregator = null;
|
OutputAggregator outputAggregator = null;
|
||||||
@ -77,7 +81,7 @@ public class EnsembleTests extends AbstractXContentTestCase<Ensemble> {
|
|||||||
return new Ensemble(featureNames,
|
return new Ensemble(featureNames,
|
||||||
models,
|
models,
|
||||||
outputAggregator,
|
outputAggregator,
|
||||||
randomFrom(TargetType.values()),
|
targetType,
|
||||||
categoryLabels);
|
categoryLabels);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,10 +57,10 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
|||||||
for (int i = 0; i < numberOfFeatures; i++) {
|
for (int i = 0; i < numberOfFeatures; i++) {
|
||||||
featureNames.add(randomAlphaOfLength(10));
|
featureNames.add(randomAlphaOfLength(10));
|
||||||
}
|
}
|
||||||
return buildRandomTree(featureNames, 6);
|
return buildRandomTree(featureNames, 6, randomFrom(TargetType.values()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Tree buildRandomTree(List<String> featureNames, int depth) {
|
public static Tree buildRandomTree(List<String> featureNames, int depth, TargetType targetType) {
|
||||||
int numFeatures = featureNames.size();
|
int numFeatures = featureNames.size();
|
||||||
Tree.Builder builder = Tree.builder();
|
Tree.Builder builder = Tree.builder();
|
||||||
builder.setFeatureNames(featureNames);
|
builder.setFeatureNames(featureNames);
|
||||||
@ -88,7 +88,7 @@ public class TreeTests extends AbstractXContentTestCase<Tree> {
|
|||||||
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
|
||||||
}
|
}
|
||||||
return builder.setClassificationLabels(categoryLabels)
|
return builder.setClassificationLabels(categoryLabels)
|
||||||
.setTargetType(randomFrom(TargetType.values()))
|
.setTargetType(targetType)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
--
|
||||||
|
:api: get-trained-models-stats
|
||||||
|
:request: GetTrainedModelsStatsRequest
|
||||||
|
:response: GetTrainedModelsStatsResponse
|
||||||
|
--
|
||||||
|
[role="xpack"]
|
||||||
|
[id="{upid}-{api}"]
|
||||||
|
=== Get Trained Models Stats API
|
||||||
|
|
||||||
|
experimental[]
|
||||||
|
|
||||||
|
Retrieves one or more Trained Model statistics.
|
||||||
|
The API accepts a +{request}+ object and returns a +{response}+.
|
||||||
|
|
||||||
|
[id="{upid}-{api}-request"]
|
||||||
|
==== Get Trained Models Stats request
|
||||||
|
|
||||||
|
A +{request}+ requires either a Trained Model ID, a comma-separated list of
|
||||||
|
IDs, or the special wildcard `_all` to get stats for all Trained Models.
|
||||||
|
|
||||||
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
|
--------------------------------------------------
|
||||||
|
include-tagged::{doc-tests-file}[{api}-request]
|
||||||
|
--------------------------------------------------
|
||||||
|
<1> Constructing a new GET request referencing an existing Trained Model
|
||||||
|
<2> Set the paging parameters
|
||||||
|
<3> Allow empty response if no Trained Models match the provided ID patterns.
|
||||||
|
If false, an error will be thrown if no Trained Models match the
|
||||||
|
ID patterns.
|
||||||
|
|
||||||
|
include::../execution.asciidoc[]
|
||||||
|
|
||||||
|
[id="{upid}-{api}-response"]
|
||||||
|
==== Response
|
||||||
|
|
||||||
|
The returned +{response}+ contains the statistics
|
||||||
|
for the requested Trained Model.
|
||||||
|
|
||||||
|
["source","java",subs="attributes,callouts,macros"]
|
||||||
|
--------------------------------------------------
|
||||||
|
include-tagged::{doc-tests-file}[{api}-response]
|
||||||
|
--------------------------------------------------
|
@ -302,6 +302,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
|
|||||||
* <<{upid}-evaluate-data-frame>>
|
* <<{upid}-evaluate-data-frame>>
|
||||||
* <<{upid}-explain-data-frame-analytics>>
|
* <<{upid}-explain-data-frame-analytics>>
|
||||||
* <<{upid}-get-trained-models>>
|
* <<{upid}-get-trained-models>>
|
||||||
|
* <<{upid}-get-trained-models-stats>>
|
||||||
* <<{upid}-delete-trained-model>>
|
* <<{upid}-delete-trained-model>>
|
||||||
* <<{upid}-put-filter>>
|
* <<{upid}-put-filter>>
|
||||||
* <<{upid}-get-filters>>
|
* <<{upid}-get-filters>>
|
||||||
@ -356,6 +357,7 @@ include::ml/stop-data-frame-analytics.asciidoc[]
|
|||||||
include::ml/evaluate-data-frame.asciidoc[]
|
include::ml/evaluate-data-frame.asciidoc[]
|
||||||
include::ml/explain-data-frame-analytics.asciidoc[]
|
include::ml/explain-data-frame-analytics.asciidoc[]
|
||||||
include::ml/get-trained-models.asciidoc[]
|
include::ml/get-trained-models.asciidoc[]
|
||||||
|
include::ml/get-trained-models-stats.asciidoc[]
|
||||||
include::ml/delete-trained-model.asciidoc[]
|
include::ml/delete-trained-model.asciidoc[]
|
||||||
include::ml/put-filter.asciidoc[]
|
include::ml/put-filter.asciidoc[]
|
||||||
include::ml/get-filters.asciidoc[]
|
include::ml/get-filters.asciidoc[]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user