From 060e0a6277d675b9e8623685c60fa872d8b10b3c Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 7 Jan 2020 09:21:59 -0500 Subject: [PATCH] [ML][Inference] Add support for models shipped as resources (#50680) (#50700) This adds support for models that are shipped as resources in the ML plugin, the first of which is the `lang_ident` model. --- .../client/MachineLearningIT.java | 10 +- .../action/GetTrainedModelsStatsAction.java | 6 + .../core/ml/inference/TrainedModelConfig.java | 5 + .../xpack/core/ml/job/messages/Messages.java | 4 +- .../ml/integration/InferenceIngestIT.java | 26 ++++ .../xpack/ml/integration/TrainedModelIT.java | 42 +++---- .../TransportGetTrainedModelsAction.java | 2 +- .../persistence/TrainedModelProvider.java | 111 +++++++++++++++++- .../TrainedModelProviderTests.java | 74 ++++++++++++ .../LangIdentNeuralNetworkInferenceTests.java | 30 ++--- .../rest-api-spec/test/ml/inference_crud.yml | 23 ++-- .../test/ml/inference_stats_crud.yml | 54 ++++----- 12 files changed, 290 insertions(+), 97 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 45cbbd632d5..4d3cae02b2a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -2198,8 +2198,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { GetTrainedModelsResponse getTrainedModelsResponse = execute( GetTrainedModelsRequest.getAllTrainedModelConfigsRequest(), machineLearningClient::getTrainedModels, machineLearningClient::getTrainedModelsAsync); - assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels)); - 
assertThat(getTrainedModelsResponse.getCount(), equalTo(5L)); + assertThat(getTrainedModelsResponse.getTrainedModels(), hasSize(numberOfModels + 1)); + assertThat(getTrainedModelsResponse.getCount(), equalTo(5L + 1)); } { GetTrainedModelsResponse getTrainedModelsResponse = execute( @@ -2222,7 +2222,7 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { public void testGetTrainedModelsStats() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String modelIdPrefix = "get-trained-model-stats-"; + String modelIdPrefix = "a-get-trained-model-stats-"; int numberOfModels = 5; for (int i = 0; i < numberOfModels; ++i) { String modelId = modelIdPrefix + i; @@ -2254,8 +2254,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase { GetTrainedModelsStatsResponse getTrainedModelsStatsResponse = execute( GetTrainedModelsStatsRequest.getAllTrainedModelStatsRequest(), machineLearningClient::getTrainedModelsStats, machineLearningClient::getTrainedModelsStatsAsync); - assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels)); - assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L)); + assertThat(getTrainedModelsStatsResponse.getTrainedModelStats(), hasSize(numberOfModels + 1)); + assertThat(getTrainedModelsStatsResponse.getCount(), equalTo(5L + 1)); assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(0).getPipelineCount(), equalTo(1)); assertThat(getTrainedModelsStatsResponse.getTrainedModelStats().get(1).getPipelineCount(), equalTo(0)); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index f3cb43e8ef7..3e91fd0444b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -98,6 +99,10 @@ public class GetTrainedModelsStatsAction extends ActionType(trainedModelStats, totalModelCount, RESULTS_FIELD)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index d0056dce734..be4d40efc85 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -409,6 +409,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { return this; } + public Builder clearDefinition() { + this.definition = null; + return this; + } + private Builder setLazyDefinition(TrainedModelDefinition.Builder parsedTrainedModel) { if (parsedTrainedModel == null) { return this; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 5e7d3ee3318..1b80d196359 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -87,10 +87,12 @@ public final class Messages { public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String 
MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; + public static final String INFERENCE_CANNOT_DELETE_MODEL = + "Unable to delete model [{0}]"; public static final String MODEL_DEFINITION_TRUNCATED = "Model definition truncated. Unable to deserialize trained model definition [{0}]"; public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]"; - public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED = + public static final String INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED = "Getting model definition is not supported when getting more than one model"; public static final String INFERENCE_WARNING_ALL_FIELDS_MISSING = "Model [{0}] could not be inferred as all fields were missing"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index d968903f236..20aea0f6316 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -233,6 +233,32 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { containsString("Could not find trained model [test_classification_missing]")); } + public void testSimulateLangIdent() { + String source = "{\n" + + " \"pipeline\": {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\": {\n" + + " \"inference_config\": {\"classification\":{}},\n" + + " \"model_id\": \"lang_ident_model_1\",\n" + + " \"field_mappings\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"text\": \"this is some plain text.\"\n" + + " }}]\n" + + "}"; + + SimulatePipelineResponse response = 
client().admin().cluster() + .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), + XContentType.JSON).get(); + SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); + assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en")); + } + private Map generateSourceDoc() { return new HashMap(){{ put("col1", randomFrom("female", "male")); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 5502a6c2c80..5911732257e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -60,8 +60,8 @@ public class TrainedModelIT extends ESRestTestCase { } public void testGetTrainedModels() throws IOException { - String modelId = "test_regression_model"; - String modelId2 = "test_regression_model-2"; + String modelId = "a_test_regression_model"; + String modelId2 = "a_test_regression_model-2"; Request model1 = new Request("PUT", InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId); model1.setJsonEntity(buildRegressionModel(modelId)); @@ -84,36 +84,36 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); String response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + 
"inference/test_regression*")); + MachineLearning.BASE_PATH + "inference/a_test_regression*")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); - assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\"")); assertThat(response, not(containsString("\"definition\""))); assertThat(response, containsString("\"count\":2")); getModel = client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true")); + MachineLearning.BASE_PATH + "inference/a_test_regression_model?human=true&include_model_definition=true")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"estimated_heap_memory_usage\"")); assertThat(response, containsString("\"definition\"")); assertThat(response, containsString("\"count\":1")); getModel = client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + "inference/test_regression_model?decompress_definition=false&include_model_definition=true")); + MachineLearning.BASE_PATH + "inference/a_test_regression_model?decompress_definition=false&include_model_definition=true")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, 
containsString("\"model_id\":\"test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); assertThat(response, containsString("\"estimated_heap_memory_usage_bytes\"")); assertThat(response, containsString("\"compressed_definition\"")); assertThat(response, not(containsString("\"definition\""))); @@ -121,17 +121,17 @@ public class TrainedModelIT extends ESRestTestCase { ResponseException responseException = expectThrows(ResponseException.class, () -> client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true"))); + MachineLearning.BASE_PATH + "inference/a_test_regression*?human=true&include_model_definition=true"))); assertThat(EntityUtils.toString(responseException.getResponse().getEntity()), - containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED)); + containsString(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED)); getModel = client().performRequest(new Request("GET", - MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2")); + MachineLearning.BASE_PATH + "inference/a_test_regression_model,a_test_regression_model-2")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); - assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\"")); assertThat(response, containsString("\"count\":2")); getModel = client().performRequest(new Request("GET", @@ -149,17 +149,17 @@ public class TrainedModelIT extends ESRestTestCase { assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - 
assertThat(response, containsString("\"count\":2")); - assertThat(response, containsString("\"model_id\":\"test_regression_model\"")); - assertThat(response, not(containsString("\"model_id\":\"test_regression_model-2\""))); + assertThat(response, containsString("\"count\":3")); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model\"")); + assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model-2\""))); getModel = client().performRequest(new Request("GET", MachineLearning.BASE_PATH + "inference?from=1&size=1")); assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200)); response = EntityUtils.toString(getModel.getEntity()); - assertThat(response, containsString("\"count\":2")); - assertThat(response, not(containsString("\"model_id\":\"test_regression_model\""))); - assertThat(response, containsString("\"model_id\":\"test_regression_model-2\"")); + assertThat(response, containsString("\"count\":3")); + assertThat(response, not(containsString("\"model_id\":\"a_test_regression_model\""))); + assertThat(response, containsString("\"model_id\":\"a_test_regression_model-2\"")); } public void testDeleteTrainedModels() throws IOException { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index 15629579368..4b467912e1b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -50,7 +50,7 @@ public class TransportGetTrainedModelsAction extends HandledTransportAction 1) { listener.onFailure( - ExceptionsHelper.badRequestException(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED) + ExceptionsHelper.badRequestException(Messages.INFERENCE_TOO_MANY_DEFINITIONS_REQUESTED) ); return; } diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index ec9e133aceb..2c14fd70f10 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.ml.inference.persistence; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; @@ -31,6 +32,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; @@ -39,6 +41,7 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.IndexNotFoundException; @@ -65,8 +68,10 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.io.InputStream; +import java.net.URL; import java.util.ArrayList; import 
java.util.Collections; +import java.util.Comparator; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; @@ -79,6 +84,10 @@ import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INFERENCE_FA public class TrainedModelProvider { + public static final Set MODELS_STORED_AS_RESOURCE = Collections.singleton("lang_ident_model_1"); + private static final String MODEL_RESOURCE_PATH = "/org/elasticsearch/xpack/ml/inference/persistence/"; + private static final String MODEL_RESOURCE_FILE_EXT = ".json"; + private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; private final NamedXContentRegistry xContentRegistry; @@ -92,6 +101,12 @@ public class TrainedModelProvider { public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + return; + } + try { trainedModelConfig.ensureParsedDefinition(xContentRegistry); } catch (IOException ex) { @@ -185,6 +200,16 @@ public class TrainedModelProvider { public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { + try { + listener.onResponse(loadModelFromResource(modelId, includeDefinition == false)); + return; + } catch (ElasticsearchException ex) { + listener.onFailure(ex); + return; + } + } + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders .idsQuery() .addIds(modelId)); @@ -268,11 +293,29 @@ public class TrainedModelProvider { .addSort("_index", SortOrder.DESC) .setQuery(queryBuilder) .request(); + List configs = new ArrayList<>(modelIds.size()); + Set modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); + Set 
modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); + for(String modelId : modelsAsResource) { + try { + configs.add(loadModelFromResource(modelId, true)); + } catch (ElasticsearchException ex) { + listener.onFailure(ex); + return; + } + } + if (modelsInIndex.isEmpty()) { + configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); + listener.onResponse(configs); + return; + } ActionListener configSearchHandler = ActionListener.wrap( searchResponse -> { - Set observedIds = new HashSet<>(searchResponse.getHits().getHits().length, 1.0f); - List configs = new ArrayList<>(searchResponse.getHits().getHits().length); + Set observedIds = new HashSet<>( + searchResponse.getHits().getHits().length + modelsAsResource.size(), + 1.0f); + observedIds.addAll(modelsAsResource); for(SearchHit searchHit : searchResponse.getHits().getHits()) { try { if (observedIds.contains(searchHit.getId()) == false) { @@ -295,6 +338,8 @@ public class TrainedModelProvider { listener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); return; } + // Ensure sorted even with the injection of locally resourced models + configs.sort(Comparator.comparing(TrainedModelConfig::getModelId)); listener.onResponse(configs); }, listener::onFailure @@ -304,6 +349,10 @@ public class TrainedModelProvider { } public void deleteTrainedModel(String modelId, ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { + listener.onFailure(ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, modelId))); + return; + } DeleteByQueryRequest request = new DeleteByQueryRequest().setAbortOnVersionConflict(false); request.indices(InferenceIndexConstants.INDEX_PATTERN); @@ -360,8 +409,8 @@ public class TrainedModelProvider { searchRequest, ActionListener.wrap( response -> { - Set foundResourceIds = new LinkedHashSet<>(); - long totalHitCount = response.getHits().getTotalHits().value; + Set 
foundResourceIds = new LinkedHashSet<>(matchedResourceIds(tokens)); + long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size(); for (SearchHit hit : response.getHits().getHits()) { Map docSource = hit.getSourceAsMap(); if (docSource == null) { @@ -386,6 +435,37 @@ public class TrainedModelProvider { } + TrainedModelConfig loadModelFromResource(String modelId, boolean nullOutDefinition) { + URL resource = getClass().getResource(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT); + if (resource == null) { + logger.error("[{}] presumed stored as a resource but not found", modelId); + throw new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId)); + } + try { + BytesReference bytes = Streams.readFully(getClass() + .getResourceAsStream(MODEL_RESOURCE_PATH + modelId + MODEL_RESOURCE_FILE_EXT)); + try (XContentParser parser = + XContentHelper.createParser(xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + bytes, + XContentType.JSON)) { + TrainedModelConfig.Builder builder = TrainedModelConfig.fromXContent(parser, true); + if (nullOutDefinition) { + builder.clearDefinition(); + } + return builder.build(); + } catch (IOException ioEx) { + logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), ioEx); + throw ExceptionsHelper.serverError(INFERENCE_FAILED_TO_DESERIALIZE, ioEx, modelId); + } + } catch (IOException ex) { + String msg = new ParameterizedMessage("[{}] failed to read model as resource", modelId).getFormattedMessage(); + logger.error(msg, ex); + throw ExceptionsHelper.serverError(msg, ex); + } + } + private QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String resourceIdField) { BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelConfig.NAME)); @@ -414,6 +494,29 @@ public class TrainedModelProvider { return boolQuery; } + private Set 
matchedResourceIds(String[] tokens) { + if (Strings.isAllOrWildcard(tokens)) { + return new HashSet<>(MODELS_STORED_AS_RESOURCE); + } + + Set matchedModels = new HashSet<>(); + + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + for (String modelId : MODELS_STORED_AS_RESOURCE) { + if(Regex.simpleMatch(token, modelId)) { + matchedModels.add(modelId); + } + } + } else { + if (MODELS_STORED_AS_RESOURCE.contains(token)) { + matchedModels.add(token); + } + } + } + return matchedModels; + } + private static T handleSearchItem(MultiSearchResponse.Item item, String resourceId, CheckedBiFunction parseLeniently) throws Exception { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java new file mode 100644 index 00000000000..705a8f60dd2 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.inference.persistence; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + +public class TrainedModelProviderTests extends ESTestCase { + + public void testDeleteModelStoredAsResource() { + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + PlainActionFuture future = new PlainActionFuture<>(); + // Should be OK as we don't make any client calls + trainedModelProvider.deleteTrainedModel("lang_ident_model_1", future); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_CANNOT_DELETE_MODEL, "lang_ident_model_1"))); + } + + public void testPutModelThatExistsAsResource() { + TrainedModelConfig config = TrainedModelConfigTests.createTestInstance("lang_ident_model_1").build(); + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + PlainActionFuture future = new PlainActionFuture<>(); + trainedModelProvider.storeTrainedModel(config, future); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, future::actionGet); + assertThat(ex.getMessage(), 
equalTo(Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, "lang_ident_model_1"))); + } + + public void testGetModelThatExistsAsResource() throws Exception { + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + for(String modelId : TrainedModelProvider.MODELS_STORED_AS_RESOURCE) { + PlainActionFuture future = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModel(modelId, true, future); + TrainedModelConfig configWithDefinition = future.actionGet(); + + assertThat(configWithDefinition.getModelId(), equalTo(modelId)); + assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue()))); + + PlainActionFuture futureNoDefinition = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModel(modelId, false, futureNoDefinition); + TrainedModelConfig configWithoutDefinition = futureNoDefinition.actionGet(); + + assertThat(configWithoutDefinition.getModelId(), equalTo(modelId)); + assertThat(configWithDefinition.ensureParsedDefinition(xContentRegistry()).getModelDefinition(), is(not(nullValue()))); + } + } + + public void testGetModelThatExistsAsResourceButIsMissing() { + TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, + () -> trainedModelProvider.loadModelFromResource("missing_model", randomBoolean())); + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, "missing_model"))); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java index c93d2d58dde..8357f1cb782 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/trainedmodels/langident/LangIdentNeuralNetworkInferenceTests.java @@ -5,10 +5,9 @@ */ package org.elasticsearch.xpack.ml.inference.trainedmodels.langident; -import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Client; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; @@ -16,22 +15,26 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LanguageExamples; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; -import java.io.IOException; -import java.nio.file.Files; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.closeTo; - +import static org.mockito.Mockito.mock; public class LangIdentNeuralNetworkInferenceTests extends ESTestCase { public void testLangInference() throws Exception { - TrainedModelConfig config = getLangIdentModel(); + TrainedModelProvider trainedModelProvider = 
new TrainedModelProvider(mock(Client.class), xContentRegistry()); + PlainActionFuture future = new PlainActionFuture<>(); + // Should be OK as we don't make any client calls + trainedModelProvider.getTrainedModel("lang_ident_model_1", true, future); + TrainedModelConfig config = future.actionGet(); + config.ensureParsedDefinition(xContentRegistry()); TrainedModelDefinition trainedModelDefinition = config.getModelDefinition(); List examples = new LanguageExamples().getLanguageExamples(); ClassificationConfig classificationConfig = new ClassificationConfig(1); @@ -53,19 +56,6 @@ public class LangIdentNeuralNetworkInferenceTests extends ESTestCase { } } - private TrainedModelConfig getLangIdentModel() throws IOException { - String path = "/org/elasticsearch/xpack/ml/inference/persistence/lang_ident_model_1.json"; - try(XContentParser parser = - XContentType.JSON.xContent().createParser( - NamedXContentRegistry.EMPTY, - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - Files.newInputStream(getDataPath(path)))) { - TrainedModelConfig config = TrainedModelConfig.fromXContent(parser, true).build(); - config.ensureParsedDefinition(xContentRegistry()); - return config; - } - } - @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index a8b199a7a3b..f72fd1120d8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -1,18 +1,3 @@ ---- -"Test get-all given no trained models exist": - - - do: - ml.get_trained_models: - model_id: "_all" - - match: { count: 0 } - - match: { trained_model_configs: [] } - - - do: - ml.get_trained_models: - model_id: "*" - - match: { count: 0 } - - match: { 
trained_model_configs: [] } - --- "Test get given missing trained model": @@ -111,3 +96,11 @@ catch: conflict ml.delete_trained_model: model_id: "used-regression-model" +--- +"Test get pre-packaged trained models": + - do: + ml.get_trained_models: + model_id: "_all" + allow_no_match: false + - match: { count: 1 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index 6062f651906..5143a690b01 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -5,17 +5,15 @@ setup: headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: - id: trained_model_config-unused-regression-model1-0 + id: trained_model_config-a-unused-regression-model1-0 index: .ml-inference-000001 body: > { - "model_id": "unused-regression-model1", + "model_id": "a-unused-regression-model1", "created_by": "ml_tests", "version": "8.0.0", "description": "empty model for tests", "create_time": 0, - "model_version": 0, - "model_type": "local", "doc_type": "trained_model_config" } @@ -23,34 +21,30 @@ setup: headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser index: - id: trained_model_config-unused-regression-model-0 + id: trained_model_config-a-unused-regression-model-0 index: .ml-inference-000001 body: > { - "model_id": "unused-regression-model", + "model_id": "a-unused-regression-model", "created_by": "ml_tests", "version": "8.0.0", "description": "empty model for tests", "create_time": 0, - "model_version": 0, - "model_type": "local", "doc_type": "trained_model_config" } - do: headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: - id: trained_model_config-used-regression-model-0 + id: trained_model_config-a-used-regression-model-0 index: .ml-inference-000001 body: > { - "model_id": "used-regression-model", + "model_id": "a-used-regression-model", "created_by": "ml_tests", "version": "8.0.0", "description": "empty model for tests", "create_time": 0, - "model_version": 0, - "model_type": "local", "doc_type": "trained_model_config" } @@ -69,7 +63,7 @@ setup: "processors": [ { "inference" : { - "model_id" : "used-regression-model", + "model_id" : "a-used-regression-model", "inference_config": {"regression": {}}, "target_field": "regression_field", "field_mappings": {} @@ -87,7 +81,7 @@ setup: "processors": [ { "inference" : { - "model_id" : "used-regression-model", + "model_id" : "a-used-regression-model", "inference_config": {"regression": {}}, "target_field": "regression_field", "field_mappings": {} @@ -125,18 +119,18 @@ setup: - do: ml.get_trained_models_stats: - model_id: "unused-regression-model" + model_id: "a-unused-regression-model" - match: { count: 1 } - do: ml.get_trained_models_stats: model_id: "_all" - - match: { count: 3 } - - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { count: 4 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } - is_false: trained_model_stats.0.ingest - - 
match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.model_id: a-unused-regression-model1 } - match: { trained_model_stats.1.pipeline_count: 0 } - is_false: trained_model_stats.1.ingest - match: { trained_model_stats.2.pipeline_count: 2 } @@ -145,11 +139,11 @@ setup: - do: ml.get_trained_models_stats: model_id: "*" - - match: { count: 3 } - - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { count: 4 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } - is_false: trained_model_stats.0.ingest - - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.model_id: a-unused-regression-model1 } - match: { trained_model_stats.1.pipeline_count: 0 } - is_false: trained_model_stats.1.ingest - match: { trained_model_stats.2.pipeline_count: 2 } @@ -157,40 +151,40 @@ setup: - do: ml.get_trained_models_stats: - model_id: "unused-regression-model*" + model_id: "a-unused-regression-model*" - match: { count: 2 } - - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.model_id: a-unused-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } - is_false: trained_model_stats.0.ingest - - match: { trained_model_stats.1.model_id: unused-regression-model1 } + - match: { trained_model_stats.1.model_id: a-unused-regression-model1 } - match: { trained_model_stats.1.pipeline_count: 0 } - is_false: trained_model_stats.1.ingest - do: ml.get_trained_models_stats: - model_id: "unused-regression-model*" + model_id: "a-unused-regression-model*" size: 1 - match: { count: 2 } - - match: { trained_model_stats.0.model_id: unused-regression-model } + - match: { trained_model_stats.0.model_id: a-unused-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } - is_false: trained_model_stats.0.ingest - do: 
ml.get_trained_models_stats: - model_id: "unused-regression-model*" + model_id: "a-unused-regression-model*" from: 1 size: 1 - match: { count: 2 } - - match: { trained_model_stats.0.model_id: unused-regression-model1 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model1 } - match: { trained_model_stats.0.pipeline_count: 0 } - is_false: trained_model_stats.0.ingest - do: ml.get_trained_models_stats: - model_id: "used-regression-model" + model_id: "a-used-regression-model" - match: { count: 1 } - - match: { trained_model_stats.0.model_id: used-regression-model } + - match: { trained_model_stats.0.model_id: a-used-regression-model } - match: { trained_model_stats.0.pipeline_count: 2 } - match: trained_model_stats.0.ingest.total: