This adds support for models that are shipped as resources in the ML plugin. The first of which is the `lang_ident` model.
This commit is contained in:
parent
98ca9500e8
commit
060e0a6277
|
@ -2198,8 +2198,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
||||||
GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(),
|
GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(),
|
||||||
machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
|
machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync);
|
||||||
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels));
|
assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels + 1));
|
||||||
assertThat(getTrainedModelsResponse.getCount(), equalTo(5L));
|
assertThat(getTrainedModelsResponse.getCount(), equalTo(5L + 1));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
GetTrainedModelsResponse getTrainedModelsResponse = execute(
|
||||||
|
@ -2222,7 +2222,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
|
|
||||||
public void testGetTrainedModelsStats() throws Exception {
|
public void testGetTrainedModelsStats() throws Exception {
|
||||||
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
|
||||||
String modelIdPrefix = "get-trained-model-stats-";
|
String modelIdPrefix = "a-get-trained-model-stats-";
|
||||||
int numberOfModels = 5;
|
int numberOfModels = 5;
|
||||||
for (int i = 0; i < numberOfModels; ++i) {
|
for (int i = 0; i < numberOfModels; ++i) {
|
||||||
String modelId = modelIdPrefix + i;
|
String modelId = modelIdPrefix + i;
|
||||||
|
@ -2254,8 +2254,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
|
||||||
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute(
|
||||||
GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
|
GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(),
|
||||||
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync);
|
||||||
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels));
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels + 1));
|
||||||
assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L));
|
assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L + 1));
|
||||||
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1));
|
||||||
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
|
assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0));
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
@ -98,6 +99,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
|
||||||
pipelineCount = in.readVInt();
|
pipelineCount = in.readVInt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public String getModelId() {
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||||
builder.startObject();
|
builder.startObject();
|
||||||
|
@ -186,6 +191,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
|
||||||
0 :
|
0 :
|
||||||
ingestStats.getPipelineStats().size()));
|
ingestStats.getPipelineStats().size()));
|
||||||
});
|
});
|
||||||
|
trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
|
||||||
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
|
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -409,6 +409,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Builder clearDefinition() {
|
||||||
|
this.definition = null;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
private Builder setLazyDefinition(TrainedModelDefinition.Builder parsedTrainedModel) {
|
private Builder setLazyDefinition(TrainedModelDefinition.Builder parsedTrainedModel) {
|
||||||
if (parsedTrainedModel == null) {
|
if (parsedTrainedModel == null) {
|
||||||
return this;
|
return this;
|
||||||
|
|
|
@ -87,10 +87,12 @@ public final class Messages {
|
||||||
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
|
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
|
||||||
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
|
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
|
||||||
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
|
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
|
||||||
|
public static final String INFERENCE_CANNOT_DELETE_MODEL =
|
||||||
|
"Unable to delete model [{0}]";
|
||||||
public static final String MODEL_DEFINITION_TRUNCATED =
|
public static final String MODEL_DEFINITION_TRUNCATED =
|
||||||
"Model definition truncated. Unable to deserialize trained model definition [{0}]";
|
"Model definition truncated. Unable to deserialize trained model definition [{0}]";
|
||||||
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
|
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
|
||||||
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
|
public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED =
|
||||||
"Getting model definition is not supported when getting more than one model";
|
"Getting model definition is not supported when getting more than one model";
|
||||||
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
|
public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing";
|
||||||
|
|
||||||
|
|
|
@ -233,6 +233,32 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
|
||||||
containsString("Could not find trained model [test_classification_missing]"));
|
containsString("Could not find trained model [test_classification_missing]"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testSimulateLangIdent() {
|
||||||
|
String source = "{\n" +
|
||||||
|
" \"pipeline\": {\n" +
|
||||||
|
" \"processors\": [\n" +
|
||||||
|
" {\n" +
|
||||||
|
" \"inference\": {\n" +
|
||||||
|
" \"inference_config\": {\"classification\":{}},\n" +
|
||||||
|
" \"model_id\": \"lang_ident_model_1\",\n" +
|
||||||
|
" \"field_mappings\": {}\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
" ]\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"docs\": [\n" +
|
||||||
|
" {\"_source\": {\n" +
|
||||||
|
" \"text\": \"this is some plain text.\"\n" +
|
||||||
|
" }}]\n" +
|
||||||
|
"}";
|
||||||
|
|
||||||
|
SimulatePipelineResponse response = client().admin().cluster()
|
||||||
|
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
|
||||||
|
XContentType.JSON).get();
|
||||||
|
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
|
||||||
|
assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
|
||||||
|
}
|
||||||
|
|
||||||
private Map<String, Object> generateSourceDoc() {
|
private Map<String, Object> generateSourceDoc() {
|
||||||
return new HashMap<String, Object>(){{
|
return new HashMap<String, Object>(){{
|
||||||
put("col1", randomFrom("female", "male"));
|
put("col1", randomFrom("female", "male"));
|
||||||
|
|
|
@ -60,8 +60,8 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testGetTrainedModels() throws IOException {
|
public void testGetTrainedModels() throws IOException {
|
||||||
String modelId = "test_regression_model";
|
String modelId = "a_test_regression_model";
|
||||||
String modelId2 = "test_regression_model-2";
|
String modelId2 = "a_test_regression_model-2";
|
||||||
Request model1 = new Request("PUT",
|
Request model1 = new Request("PUT",
|
||||||
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId);
|
||||||
model1.setJsonEntity(buildRegressionModel(modelId));
|
model1.setJsonEntity(buildRegressionModel(modelId));
|
||||||
|
@ -84,36 +84,36 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
String response = EntityUtils.toString(getModel.getEntity());
|
String response = EntityUtils.toString(getModel.getEntity());
|
||||||
|
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, containsString("\"count\":1"));
|
assertThat(response, containsString("\"count\":1"));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
MachineLearning.BASE_PATH + "inference/test_regression*"));
|
MachineLearning.BASE_PATH + "inference/a_test_regression*"));
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
|
||||||
assertThat(response, not(containsString("\"definition\"")));
|
assertThat(response, not(containsString("\"definition\"")));
|
||||||
assertThat(response, containsString("\"count\":2"));
|
assertThat(response, containsString("\"count\":2"));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
|
MachineLearning.BASE_PATH + "inference/a_test_regression_model?human=true&include_model_definition=true"));
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
||||||
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
|
assertThat(response, containsString("\"estimated_heap_memory_usage\""));
|
||||||
assertThat(response, containsString("\"definition\""));
|
assertThat(response, containsString("\"definition\""));
|
||||||
assertThat(response, containsString("\"count\":1"));
|
assertThat(response, containsString("\"count\":1"));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
MachineLearning.BASE_PATH + "inference/test_regression_model?decompress_definition=false&include_model_definition=true"));
|
MachineLearning.BASE_PATH + "inference/a_test_regression_model?decompress_definition=false&include_model_definition=true"));
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\""));
|
||||||
assertThat(response, containsString("\"compressed_definition\""));
|
assertThat(response, containsString("\"compressed_definition\""));
|
||||||
assertThat(response, not(containsString("\"definition\"")));
|
assertThat(response, not(containsString("\"definition\"")));
|
||||||
|
@ -121,17 +121,17 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
|
|
||||||
ResponseException responseException = expectThrows(ResponseException.class, () ->
|
ResponseException responseException = expectThrows(ResponseException.class, () ->
|
||||||
client().performRequest(new Request("GET",
|
client().performRequest(new Request("GET",
|
||||||
MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
|
MachineLearning.BASE_PATH + "inference/a_test_regression*?human=true&include_model_definition=true")));
|
||||||
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
|
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
|
||||||
containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
|
containsString(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
|
MachineLearning.BASE_PATH + "inference/a_test_regression_model,a_test_regression_model-2"));
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
|
||||||
assertThat(response, containsString("\"count\":2"));
|
assertThat(response, containsString("\"count\":2"));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET",
|
getModel = client().performRequest(new Request("GET",
|
||||||
|
@ -149,17 +149,17 @@ public class TrainedModelIT extends ESRestTestCase {
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"count\":2"));
|
assertThat(response, containsString("\"count\":3"));
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model\""));
|
||||||
assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\"")));
|
assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model-2\"")));
|
||||||
|
|
||||||
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
|
getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1"));
|
||||||
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
|
||||||
|
|
||||||
response = EntityUtils.toString(getModel.getEntity());
|
response = EntityUtils.toString(getModel.getEntity());
|
||||||
assertThat(response, containsString("\"count\":2"));
|
assertThat(response, containsString("\"count\":3"));
|
||||||
assertThat(response, not(containsString("\"model_id\":\"test_regression_model\"")));
|
assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model\"")));
|
||||||
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
|
assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\""));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testDeleteTrainedModels() throws IOException {
|
public void testDeleteTrainedModels() throws IOException {
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction<Requ
|
||||||
|
|
||||||
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
|
if (request.isIncludeModelDefinition() && totalAndIds.v2().size() > 1) {
|
||||||
listener.onFailure(
|
listener.onFailure(
|
||||||
ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)
|
ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.persistence;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
import org.apache.logging.log4j.message.ParameterizedMessage;
|
import org.apache.logging.log4j.message.ParameterizedMessage;
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
import org.elasticsearch.ElasticsearchStatusException;
|
import org.elasticsearch.ElasticsearchStatusException;
|
||||||
import org.elasticsearch.ResourceAlreadyExistsException;
|
import org.elasticsearch.ResourceAlreadyExistsException;
|
||||||
import org.elasticsearch.ResourceNotFoundException;
|
import org.elasticsearch.ResourceNotFoundException;
|
||||||
|
@ -31,6 +32,7 @@ import org.elasticsearch.common.Nullable;
|
||||||
import org.elasticsearch.common.Strings;
|
import org.elasticsearch.common.Strings;
|
||||||
import org.elasticsearch.common.bytes.BytesReference;
|
import org.elasticsearch.common.bytes.BytesReference;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
|
import org.elasticsearch.common.io.Streams;
|
||||||
import org.elasticsearch.common.regex.Regex;
|
import org.elasticsearch.common.regex.Regex;
|
||||||
import org.elasticsearch.common.util.set.Sets;
|
import org.elasticsearch.common.util.set.Sets;
|
||||||
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
|
||||||
|
@ -39,6 +41,7 @@ import org.elasticsearch.common.xcontent.ToXContent;
|
||||||
import org.elasticsearch.common.xcontent.ToXContentObject;
|
import org.elasticsearch.common.xcontent.ToXContentObject;
|
||||||
import org.elasticsearch.common.xcontent.XContentBuilder;
|
import org.elasticsearch.common.xcontent.XContentBuilder;
|
||||||
import org.elasticsearch.common.xcontent.XContentFactory;
|
import org.elasticsearch.common.xcontent.XContentFactory;
|
||||||
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
import org.elasticsearch.common.xcontent.XContentParser;
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
import org.elasticsearch.common.xcontent.XContentType;
|
||||||
import org.elasticsearch.index.IndexNotFoundException;
|
import org.elasticsearch.index.IndexNotFoundException;
|
||||||
|
@ -65,8 +68,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
import java.net.URL;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -79,6 +84,10 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FA
|
||||||
|
|
||||||
public class TrainedModelProvider {
|
public class TrainedModelProvider {
|
||||||
|
|
||||||
|
public static final Set<String> MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1");
|
||||||
|
private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/";
|
||||||
|
private static final String MODEL_RESOURCE_FILE_EXT = ".json";
|
||||||
|
|
||||||
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class);
|
||||||
private final Client client;
|
private final Client client;
|
||||||
private final NamedXContentRegistry xContentRegistry;
|
private final NamedXContentRegistry xContentRegistry;
|
||||||
|
@ -92,6 +101,12 @@ public class TrainedModelProvider {
|
||||||
|
|
||||||
public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
|
public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
|
||||||
ActionListener<Boolean> listener) {
|
ActionListener<Boolean> listener) {
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) {
|
||||||
|
listener.onFailure(new ResourceAlreadyExistsException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId())));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
|
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
|
||||||
} catch (IOException ex) {
|
} catch (IOException ex) {
|
||||||
|
@ -185,6 +200,16 @@ public class TrainedModelProvider {
|
||||||
|
|
||||||
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener<TrainedModelConfig> listener) {
|
||||||
|
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||||
|
try {
|
||||||
|
listener.onResponse(loadModelFromResource(modelId, includeDefinition == false));
|
||||||
|
return;
|
||||||
|
} catch (ElasticsearchException ex) {
|
||||||
|
listener.onFailure(ex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders
|
||||||
.idsQuery()
|
.idsQuery()
|
||||||
.addIds(modelId));
|
.addIds(modelId));
|
||||||
|
@ -268,11 +293,29 @@ public class TrainedModelProvider {
|
||||||
.addSort("_index", SortOrder.DESC)
|
.addSort("_index", SortOrder.DESC)
|
||||||
.setQuery(queryBuilder)
|
.setQuery(queryBuilder)
|
||||||
.request();
|
.request();
|
||||||
|
List<TrainedModelConfig> configs = new ArrayList<>(modelIds.size());
|
||||||
|
Set<String> modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE);
|
||||||
|
Set<String> modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds);
|
||||||
|
for(String modelId : modelsAsResource) {
|
||||||
|
try {
|
||||||
|
configs.add(loadModelFromResource(modelId, true));
|
||||||
|
} catch (ElasticsearchException ex) {
|
||||||
|
listener.onFailure(ex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (modelsInIndex.isEmpty()) {
|
||||||
|
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
||||||
|
listener.onResponse(configs);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
ActionListener<SearchResponse> configSearchHandler = ActionListener.wrap(
|
||||||
searchResponse -> {
|
searchResponse -> {
|
||||||
Set<String> observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f);
|
Set<String> observedIds = new HashSet<>(
|
||||||
List<TrainedModelConfig> configs = new ArrayList<>(searchResponse.getHits().getHits().length);
|
searchResponse.getHits().getHits().length + modelsAsResource.size(),
|
||||||
|
1.0f);
|
||||||
|
observedIds.addAll(modelsAsResource);
|
||||||
for(SearchHit searchHit : searchResponse.getHits().getHits()) {
|
for(SearchHit searchHit : searchResponse.getHits().getHits()) {
|
||||||
try {
|
try {
|
||||||
if (observedIds.contains(searchHit.getId()) == false) {
|
if (observedIds.contains(searchHit.getId()) == false) {
|
||||||
|
@ -295,6 +338,8 @@ public class TrainedModelProvider {
|
||||||
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
// Ensure sorted even with the injection of locally resourced models
|
||||||
|
configs.sort(Comparator.comparing(TrainedModelConfig::getModelId));
|
||||||
listener.onResponse(configs);
|
listener.onResponse(configs);
|
||||||
},
|
},
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
|
@ -304,6 +349,10 @@ public class TrainedModelProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
|
public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener) {
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(modelId)) {
|
||||||
|
listener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, modelId)));
|
||||||
|
return;
|
||||||
|
}
|
||||||
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
|
DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false);
|
||||||
|
|
||||||
request.indices(InferenceIndexConstants.INDEX_PATTERN);
|
request.indices(InferenceIndexConstants.INDEX_PATTERN);
|
||||||
|
@ -360,8 +409,8 @@ public class TrainedModelProvider {
|
||||||
searchRequest,
|
searchRequest,
|
||||||
ActionListener.<SearchResponse>wrap(
|
ActionListener.<SearchResponse>wrap(
|
||||||
response -> {
|
response -> {
|
||||||
Set<String> foundResourceIds = new LinkedHashSet<>();
|
Set<String> foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens));
|
||||||
long totalHitCount = response.getHits().getTotalHits().value;
|
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
|
||||||
for (SearchHit hit : response.getHits().getHits()) {
|
for (SearchHit hit : response.getHits().getHits()) {
|
||||||
Map<String, Object> docSource = hit.getSourceAsMap();
|
Map<String, Object> docSource = hit.getSourceAsMap();
|
||||||
if (docSource == null) {
|
if (docSource == null) {
|
||||||
|
@ -386,6 +435,37 @@ public class TrainedModelProvider {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) {
|
||||||
|
URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT);
|
||||||
|
if (resource == null) {
|
||||||
|
logger.error("[{}] presumed stored as a resource but not found", modelId);
|
||||||
|
throw new ResourceNotFoundException(
|
||||||
|
Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId));
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
BytesReference bytes = Streams.readFully(getClass()
|
||||||
|
.getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT));
|
||||||
|
try (XContentParser parser =
|
||||||
|
XContentHelper.createParser(xContentRegistry,
|
||||||
|
LoggingDeprecationHandler.INSTANCE,
|
||||||
|
bytes,
|
||||||
|
XContentType.JSON)) {
|
||||||
|
TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true);
|
||||||
|
if (nullOutDefinition) {
|
||||||
|
builder.clearDefinition();
|
||||||
|
}
|
||||||
|
return builder.build();
|
||||||
|
} catch (IOException ioEx) {
|
||||||
|
logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx);
|
||||||
|
throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId);
|
||||||
|
}
|
||||||
|
} catch (IOException ex) {
|
||||||
|
String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage();
|
||||||
|
logger.error(msg, ex);
|
||||||
|
throw ExceptionsHelper.serverError(msg, ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) {
|
||||||
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
|
||||||
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
|
.filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME));
|
||||||
|
@ -414,6 +494,29 @@ public class TrainedModelProvider {
|
||||||
return boolQuery;
|
return boolQuery;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Set<String> matchedResourceIds(String[] tokens) {
|
||||||
|
if (Strings.isAllOrWildcard(tokens)) {
|
||||||
|
return new HashSet<>(MODELS_STORED_AS_RESOURCE);
|
||||||
|
}
|
||||||
|
|
||||||
|
Set<String> matchedModels = new HashSet<>();
|
||||||
|
|
||||||
|
for (String token : tokens) {
|
||||||
|
if (Regex.isSimpleMatchPattern(token)) {
|
||||||
|
for (String modelId : MODELS_STORED_AS_RESOURCE) {
|
||||||
|
if(Regex.simpleMatch(token, modelId)) {
|
||||||
|
matchedModels.add(modelId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (MODELS_STORED_AS_RESOURCE.contains(token)) {
|
||||||
|
matchedModels.add(token);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return matchedModels;
|
||||||
|
}
|
||||||
|
|
||||||
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
private static <T> T handleSearchItem(MultiSearchResponse.Item item,
|
||||||
String resourceId,
|
String resourceId,
|
||||||
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|
CheckedBiFunction<BytesReference, String, T, Exception> parseLeniently) throws Exception {
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.ml.inference.persistence;
|
||||||
|
|
||||||
|
import org.elasticsearch.ElasticsearchException;
|
||||||
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
|
import org.elasticsearch.client.Client;
|
||||||
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
|
import org.elasticsearch.test.ESTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
|
||||||
|
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.is;
|
||||||
|
import static org.hamcrest.Matchers.not;
|
||||||
|
import static org.hamcrest.Matchers.nullValue;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
|
public class TrainedModelProviderTests extends ESTestCase {
|
||||||
|
|
||||||
|
public void testDeleteModelStoredAsResource() {
|
||||||
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
|
PlainActionFuture<Boolean> future = new PlainActionFuture<>();
|
||||||
|
// Should be OK as we don't make any client calls
|
||||||
|
trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future);
|
||||||
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
|
||||||
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, "lang_ident_model_1")));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testPutModelThatExistsAsResource() {
|
||||||
|
TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build();
|
||||||
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
|
PlainActionFuture<Boolean> future = new PlainActionFuture<>();
|
||||||
|
trainedModelProvider.storeTrainedModel(config, future);
|
||||||
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet);
|
||||||
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, "lang_ident_model_1")));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGetModelThatExistsAsResource() throws Exception {
|
||||||
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
|
for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) {
|
||||||
|
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||||
|
trainedModelProvider.getTrainedModel(modelId, true, future);
|
||||||
|
TrainedModelConfig configWithDefinition = future.actionGet();
|
||||||
|
|
||||||
|
assertThat(configWithDefinition.getModelId(), equalTo(modelId));
|
||||||
|
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
|
||||||
|
|
||||||
|
PlainActionFuture<TrainedModelConfig> futureNoDefinition = new PlainActionFuture<>();
|
||||||
|
trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition);
|
||||||
|
TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet();
|
||||||
|
|
||||||
|
assertThat(configWithoutDefinition.getModelId(), equalTo(modelId));
|
||||||
|
assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue())));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testGetModelThatExistsAsResourceButIsMissing() {
|
||||||
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
|
ElasticsearchException ex = expectThrows(ElasticsearchException.class,
|
||||||
|
() -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean()));
|
||||||
|
assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model")));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
|
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,10 +5,9 @@
|
||||||
*/
|
*/
|
||||||
package org.elasticsearch.xpack.ml.inference.trainedmodels.langident;
|
package org.elasticsearch.xpack.ml.inference.trainedmodels.langident;
|
||||||
|
|
||||||
import org.elasticsearch.common.xcontent.DeprecationHandler;
|
import org.elasticsearch.action.support.PlainActionFuture;
|
||||||
|
import org.elasticsearch.client.Client;
|
||||||
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
||||||
import org.elasticsearch.common.xcontent.XContentParser;
|
|
||||||
import org.elasticsearch.common.xcontent.XContentType;
|
|
||||||
import org.elasticsearch.test.ESTestCase;
|
import org.elasticsearch.test.ESTestCase;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
|
||||||
|
@ -16,22 +15,26 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
|
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples;
|
||||||
|
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.nio.file.Files;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.hamcrest.CoreMatchers.equalTo;
|
import static org.hamcrest.CoreMatchers.equalTo;
|
||||||
import static org.hamcrest.Matchers.closeTo;
|
import static org.hamcrest.Matchers.closeTo;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
|
||||||
public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
|
public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
|
||||||
|
|
||||||
public void testLangInference() throws Exception {
|
public void testLangInference() throws Exception {
|
||||||
TrainedModelConfig config = getLangIdentModel();
|
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
|
||||||
|
PlainActionFuture<TrainedModelConfig> future = new PlainActionFuture<>();
|
||||||
|
// Should be OK as we don't make any client calls
|
||||||
|
trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future);
|
||||||
|
TrainedModelConfig config = future.actionGet();
|
||||||
|
|
||||||
|
config.ensureParsedDefinition(xContentRegistry());
|
||||||
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
|
TrainedModelDefinition trainedModelDefinition = config.getModelDefinition();
|
||||||
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
|
List<LanguageExamples.LanguageExampleEntry> examples = new LanguageExamples().getLanguageExamples();
|
||||||
ClassificationConfig classificationConfig = new ClassificationConfig(1);
|
ClassificationConfig classificationConfig = new ClassificationConfig(1);
|
||||||
|
@ -53,19 +56,6 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private TrainedModelConfig getLangIdentModel() throws IOException {
|
|
||||||
String path = "/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json";
|
|
||||||
try(XContentParser parser =
|
|
||||||
XContentType.JSON.xContent().createParser(
|
|
||||||
NamedXContentRegistry.EMPTY,
|
|
||||||
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
|
|
||||||
Files.newInputStream(getDataPath(path)))) {
|
|
||||||
TrainedModelConfig config = TrainedModelConfig.fromXContent(parser, true).build();
|
|
||||||
config.ensureParsedDefinition(xContentRegistry());
|
|
||||||
return config;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected NamedXContentRegistry xContentRegistry() {
|
protected NamedXContentRegistry xContentRegistry() {
|
||||||
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
|
||||||
|
|
|
@ -1,18 +1,3 @@
|
||||||
---
|
|
||||||
"Test get-all given no trained models exist":
|
|
||||||
|
|
||||||
- do:
|
|
||||||
ml.get_trained_models:
|
|
||||||
model_id: "_all"
|
|
||||||
- match: { count: 0 }
|
|
||||||
- match: { trained_model_configs: [] }
|
|
||||||
|
|
||||||
- do:
|
|
||||||
ml.get_trained_models:
|
|
||||||
model_id: "*"
|
|
||||||
- match: { count: 0 }
|
|
||||||
- match: { trained_model_configs: [] }
|
|
||||||
|
|
||||||
---
|
---
|
||||||
"Test get given missing trained model":
|
"Test get given missing trained model":
|
||||||
|
|
||||||
|
@ -111,3 +96,11 @@
|
||||||
catch: conflict
|
catch: conflict
|
||||||
ml.delete_trained_model:
|
ml.delete_trained_model:
|
||||||
model_id: "used-regression-model"
|
model_id: "used-regression-model"
|
||||||
|
---
|
||||||
|
"Test get pre-packaged trained models":
|
||||||
|
- do:
|
||||||
|
ml.get_trained_models:
|
||||||
|
model_id: "_all"
|
||||||
|
allow_no_match: false
|
||||||
|
- match: { count: 1 }
|
||||||
|
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
|
||||||
|
|
|
@ -5,17 +5,15 @@ setup:
|
||||||
headers:
|
headers:
|
||||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||||
index:
|
index:
|
||||||
id: trained_model_config-unused-regression-model1-0
|
id: trained_model_config-a-unused-regression-model1-0
|
||||||
index: .ml-inference-000001
|
index: .ml-inference-000001
|
||||||
body: >
|
body: >
|
||||||
{
|
{
|
||||||
"model_id": "unused-regression-model1",
|
"model_id": "a-unused-regression-model1",
|
||||||
"created_by": "ml_tests",
|
"created_by": "ml_tests",
|
||||||
"version": "8.0.0",
|
"version": "8.0.0",
|
||||||
"description": "empty model for tests",
|
"description": "empty model for tests",
|
||||||
"create_time": 0,
|
"create_time": 0,
|
||||||
"model_version": 0,
|
|
||||||
"model_type": "local",
|
|
||||||
"doc_type": "trained_model_config"
|
"doc_type": "trained_model_config"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,34 +21,30 @@ setup:
|
||||||
headers:
|
headers:
|
||||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||||
index:
|
index:
|
||||||
id: trained_model_config-unused-regression-model-0
|
id: trained_model_config-a-unused-regression-model-0
|
||||||
index: .ml-inference-000001
|
index: .ml-inference-000001
|
||||||
body: >
|
body: >
|
||||||
{
|
{
|
||||||
"model_id": "unused-regression-model",
|
"model_id": "a-unused-regression-model",
|
||||||
"created_by": "ml_tests",
|
"created_by": "ml_tests",
|
||||||
"version": "8.0.0",
|
"version": "8.0.0",
|
||||||
"description": "empty model for tests",
|
"description": "empty model for tests",
|
||||||
"create_time": 0,
|
"create_time": 0,
|
||||||
"model_version": 0,
|
|
||||||
"model_type": "local",
|
|
||||||
"doc_type": "trained_model_config"
|
"doc_type": "trained_model_config"
|
||||||
}
|
}
|
||||||
- do:
|
- do:
|
||||||
headers:
|
headers:
|
||||||
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
|
||||||
index:
|
index:
|
||||||
id: trained_model_config-used-regression-model-0
|
id: trained_model_config-a-used-regression-model-0
|
||||||
index: .ml-inference-000001
|
index: .ml-inference-000001
|
||||||
body: >
|
body: >
|
||||||
{
|
{
|
||||||
"model_id": "used-regression-model",
|
"model_id": "a-used-regression-model",
|
||||||
"created_by": "ml_tests",
|
"created_by": "ml_tests",
|
||||||
"version": "8.0.0",
|
"version": "8.0.0",
|
||||||
"description": "empty model for tests",
|
"description": "empty model for tests",
|
||||||
"create_time": 0,
|
"create_time": 0,
|
||||||
"model_version": 0,
|
|
||||||
"model_type": "local",
|
|
||||||
"doc_type": "trained_model_config"
|
"doc_type": "trained_model_config"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +63,7 @@ setup:
|
||||||
"processors": [
|
"processors": [
|
||||||
{
|
{
|
||||||
"inference" : {
|
"inference" : {
|
||||||
"model_id" : "used-regression-model",
|
"model_id" : "a-used-regression-model",
|
||||||
"inference_config": {"regression": {}},
|
"inference_config": {"regression": {}},
|
||||||
"target_field": "regression_field",
|
"target_field": "regression_field",
|
||||||
"field_mappings": {}
|
"field_mappings": {}
|
||||||
|
@ -87,7 +81,7 @@ setup:
|
||||||
"processors": [
|
"processors": [
|
||||||
{
|
{
|
||||||
"inference" : {
|
"inference" : {
|
||||||
"model_id" : "used-regression-model",
|
"model_id" : "a-used-regression-model",
|
||||||
"inference_config": {"regression": {}},
|
"inference_config": {"regression": {}},
|
||||||
"target_field": "regression_field",
|
"target_field": "regression_field",
|
||||||
"field_mappings": {}
|
"field_mappings": {}
|
||||||
|
@ -125,18 +119,18 @@ setup:
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "unused-regression-model"
|
model_id: "a-unused-regression-model"
|
||||||
|
|
||||||
- match: { count: 1 }
|
- match: { count: 1 }
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "_all"
|
model_id: "_all"
|
||||||
- match: { count: 3 }
|
- match: { count: 4 }
|
||||||
- match: { trained_model_stats.0.model_id: unused-regression-model }
|
- match: { trained_model_stats.0.model_id: a-unused-regression-model }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 0 }
|
- match: { trained_model_stats.0.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.0.ingest
|
- is_false: trained_model_stats.0.ingest
|
||||||
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
|
- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
|
||||||
- match: { trained_model_stats.1.pipeline_count: 0 }
|
- match: { trained_model_stats.1.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.1.ingest
|
- is_false: trained_model_stats.1.ingest
|
||||||
- match: { trained_model_stats.2.pipeline_count: 2 }
|
- match: { trained_model_stats.2.pipeline_count: 2 }
|
||||||
|
@ -145,11 +139,11 @@ setup:
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "*"
|
model_id: "*"
|
||||||
- match: { count: 3 }
|
- match: { count: 4 }
|
||||||
- match: { trained_model_stats.0.model_id: unused-regression-model }
|
- match: { trained_model_stats.0.model_id: a-unused-regression-model }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 0 }
|
- match: { trained_model_stats.0.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.0.ingest
|
- is_false: trained_model_stats.0.ingest
|
||||||
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
|
- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
|
||||||
- match: { trained_model_stats.1.pipeline_count: 0 }
|
- match: { trained_model_stats.1.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.1.ingest
|
- is_false: trained_model_stats.1.ingest
|
||||||
- match: { trained_model_stats.2.pipeline_count: 2 }
|
- match: { trained_model_stats.2.pipeline_count: 2 }
|
||||||
|
@ -157,40 +151,40 @@ setup:
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "unused-regression-model*"
|
model_id: "a-unused-regression-model*"
|
||||||
- match: { count: 2 }
|
- match: { count: 2 }
|
||||||
- match: { trained_model_stats.0.model_id: unused-regression-model }
|
- match: { trained_model_stats.0.model_id: a-unused-regression-model }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 0 }
|
- match: { trained_model_stats.0.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.0.ingest
|
- is_false: trained_model_stats.0.ingest
|
||||||
- match: { trained_model_stats.1.model_id: unused-regression-model1 }
|
- match: { trained_model_stats.1.model_id: a-unused-regression-model1 }
|
||||||
- match: { trained_model_stats.1.pipeline_count: 0 }
|
- match: { trained_model_stats.1.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.1.ingest
|
- is_false: trained_model_stats.1.ingest
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "unused-regression-model*"
|
model_id: "a-unused-regression-model*"
|
||||||
size: 1
|
size: 1
|
||||||
- match: { count: 2 }
|
- match: { count: 2 }
|
||||||
- match: { trained_model_stats.0.model_id: unused-regression-model }
|
- match: { trained_model_stats.0.model_id: a-unused-regression-model }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 0 }
|
- match: { trained_model_stats.0.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.0.ingest
|
- is_false: trained_model_stats.0.ingest
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "unused-regression-model*"
|
model_id: "a-unused-regression-model*"
|
||||||
from: 1
|
from: 1
|
||||||
size: 1
|
size: 1
|
||||||
- match: { count: 2 }
|
- match: { count: 2 }
|
||||||
- match: { trained_model_stats.0.model_id: unused-regression-model1 }
|
- match: { trained_model_stats.0.model_id: a-unused-regression-model1 }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 0 }
|
- match: { trained_model_stats.0.pipeline_count: 0 }
|
||||||
- is_false: trained_model_stats.0.ingest
|
- is_false: trained_model_stats.0.ingest
|
||||||
|
|
||||||
- do:
|
- do:
|
||||||
ml.get_trained_models_stats:
|
ml.get_trained_models_stats:
|
||||||
model_id: "used-regression-model"
|
model_id: "a-used-regression-model"
|
||||||
|
|
||||||
- match: { count: 1 }
|
- match: { count: 1 }
|
||||||
- match: { trained_model_stats.0.model_id: used-regression-model }
|
- match: { trained_model_stats.0.model_id: a-used-regression-model }
|
||||||
- match: { trained_model_stats.0.pipeline_count: 2 }
|
- match: { trained_model_stats.0.pipeline_count: 2 }
|
||||||
- match:
|
- match:
|
||||||
trained_model_stats.0.ingest.total:
|
trained_model_stats.0.ingest.total:
|
||||||
|
|
Loading…
Reference in New Issue