From 2a73e849d66a073dd66e55147fc3a95f5b681b79 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 22 Jan 2020 09:50:17 -0500 Subject: [PATCH] [ML][Inference] fixing ingest IT tests (#51267) (#51311) Converts InferenceIngestIT into a `ESRestTestCase`. closes #51201 --- .../ml/integration/InferenceIngestIT.java | 261 ++++++++---------- 1 file changed, 111 insertions(+), 150 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 9f164e40646..efe2bb2c95f 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -5,141 +5,109 @@ */ package org.elasticsearch.xpack.ml.integration; -import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; -import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; -import org.elasticsearch.action.ingest.SimulatePipelineResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; +import org.elasticsearch.test.ExternalTestCluster; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.junit.After; import org.junit.Before; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.HashMap; -import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; import static org.hamcrest.CoreMatchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { +/** + * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems + * Specifically, ensuring the accounting breaker has been reset. + * It has to do with `_simulate` not anything really to do with the ML code + */ +public class InferenceIngestIT extends ESRestTestCase { + + private static final String BASIC_AUTH_VALUE_SUPER_USER = + basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); @Before public void createBothModels() throws Exception { - client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet(); - client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet(); + Request request = new Request("PUT", "_ml/inference/test_classification"); + request.setJsonEntity(CLASSIFICATION_CONFIG); + client().performRequest(request); + + request = new Request("PUT", "_ml/inference/test_regression"); + request.setJsonEntity(REGRESSION_CONFIG); + client().performRequest(request); + } + + @Override + protected Settings restClientSettings() { + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build(); } @After - public void deleteBothModels() { - client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet(); - client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet(); + public void cleanUpData() throws Exception { + new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata(); + ESRestTestCase.waitForPendingTasks(adminClient()); + client().performRequest(new Request("DELETE", "_ml/inference/test_classification")); + client().performRequest(new Request("DELETE", "_ml/inference/test_regression")); } public void testPipelineCreationAndDeletion() throws Exception { for (int i = 0; i < 10; i++) { - assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", - new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get().isAcknowledged(), is(true)); + client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); + client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc())); + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline")); - client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) - .setSource(new HashMap(){{ - put("col1", randomFrom("female", "male")); - put("col2", randomFrom("S", "M", "L", "XL")); - put("col3", randomFrom("true", "false", "none", "other")); - put("col4", randomIntBetween(0, 10)); - }}) - .setPipeline("simple_classification_pipeline") - .get(); - - assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), - is(true)); - - assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", - new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get().isAcknowledged(), is(true)); - - client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) - .setSource(new HashMap(){{ - put("col1", randomFrom("female", "male")); - put("col2", randomFrom("S", "M", "L", "XL")); - put("col3", randomFrom("true", "false", "none", "other")); - put("col4", randomIntBetween(0, 10)); - }}) - .setPipeline("simple_regression_pipeline") - .get(); - - assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), - is(true)); + client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE)); + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); } - assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", - new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get().isAcknowledged(), is(true)); - - assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline", - new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get().isAcknowledged(), is(true)); + client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE)); + client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE)); for (int i = 0; i < 10; i++) { - client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) - .setSource(generateSourceDoc()) - .setPipeline("simple_classification_pipeline") - .get(); - - client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) - .setSource(generateSourceDoc()) - .setPipeline("simple_regression_pipeline") - .get(); + client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc())); + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); } - assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), - is(true)); + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline")); - assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), - is(true)); + client().performRequest(new Request("POST", "index_for_inference_test/_refresh")); - client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get(); - assertThat(client().search(new SearchRequest().indices("index_for_inference_test") - .source(new SearchSourceBuilder() - .size(0) - .trackTotalHits(true) - .query(QueryBuilders.boolQuery() - .filter( - QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))))).get().getHits().getTotalHits().value, - equalTo(20L)); + Response searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.existsQuery("ml.inference.regression.predicted_value")))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); - assertThat(client().search(new SearchRequest().indices("index_for_inference_test") - .source(new SearchSourceBuilder() - .size(0) - .trackTotalHits(true) - .query(QueryBuilders.boolQuery() - .filter( - QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))))) - .get() - .getHits() - .getTotalHits() - .value, - equalTo(20L)); + searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); } - public void testSimulate() { + public void testSimulate() throws IOException { String source = "{\n" + " \"pipeline\": {\n" + " \"processors\": [\n" + @@ -181,15 +149,10 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " }}]\n" + "}"; - SimulatePipelineResponse response = client().admin().cluster() - .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get(); - SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); - assertThat(baseResult.getIngestDocument().getFieldValue("ml.regression.predicted_value", Double.class), equalTo(1.0)); - assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.predicted_value", String.class), - equalTo("second")); - assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.result_class_prob", List.class).size(), - equalTo(2)); + Response response = client().performRequest(simulateRequest(source)); + String responseString = EntityUtils.toString(response.getEntity()); + assertThat(responseString, containsString("\"predicted_value\":\"second\"")); + assertThat(responseString, containsString("\"predicted_value\":1.0")); String sourceWithMissingModel = "{\n" + " \"pipeline\": {\n" + @@ -217,15 +180,13 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " }}]\n" + "}"; - response = client().admin().cluster() - .prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get(); + response = client().performRequest(simulateRequest(sourceWithMissingModel)); + responseString = EntityUtils.toString(response.getEntity()); - assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), - containsString("Could not find trained model [test_classification_missing]")); + assertThat(responseString, containsString("Could not find trained model [test_classification_missing]")); } - public void testSimulateLangIdent() { + public void testSimulateLangIdent() throws IOException { String source = "{\n" + " \"pipeline\": {\n" + " \"processors\": [\n" + @@ -244,11 +205,43 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " }}]\n" + "}"; - SimulatePipelineResponse response = client().admin().cluster() - .prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), - XContentType.JSON).get(); - SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); - assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en")); + Response response = client().performRequest(simulateRequest(source)); + assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\"")); + } + + private static Request simulateRequest(String jsonEntity) { + Request request = new Request("POST", "_ingest/pipeline/_simulate"); + request.setJsonEntity(jsonEntity); + return request; + } + + private static Request indexRequest(String index, String pipeline, Map doc) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(doc)) { + return indexRequest(index, + pipeline, + XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON)); + } + } + + private static Request indexRequest(String index, String pipeline, String doc) { + Request request = new Request("POST", index + "/_doc?pipeline=" + pipeline); + request.setJsonEntity(doc); + return request; + } + + private static Request putPipeline(String pipelineId, String pipelineDefinition) { + Request request = new Request("PUT", "_ingest/pipeline/" + pipelineId); + request.setJsonEntity(pipelineDefinition); + return request; + } + + private static Request searchRequest(String index, QueryBuilder queryBuilder) throws IOException { + BytesReference reference = XContentHelper.toXContent(queryBuilder, XContentType.JSON, false); + String queryJson = XContentHelper.convertToJson(reference, false, XContentType.JSON); + String json = "{\"query\": " + queryJson + "}"; + Request request = new Request("GET", index + "/_search?track_total_hits=true"); + request.setJsonEntity(json); + return request; } private Map generateSourceDoc() { @@ -380,16 +373,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { "}"; private static final String REGRESSION_CONFIG = "{" + - " \"model_id\": \"test_regression\",\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + - " \"version\": \"7.6.0\",\n" + - " \"definition\": " + REGRESSION_DEFINITION + ","+ - " \"license_level\": \"platinum\",\n" + - " \"created_by\": \"ml_test\",\n" + - " \"estimated_heap_memory_usage_bytes\": 0," + - " \"estimated_operations\": 0," + - " \"created_time\": 0" + + " \"definition\": " + REGRESSION_DEFINITION + "}"; private static final String CLASSIFICATION_DEFINITION = "{" + @@ -512,24 +498,6 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { " }\n" + "}"; - private TrainedModelConfig buildClassificationModel() throws IOException { - try (XContentParser parser = XContentHelper.createParser(xContentRegistry(), - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(CLASSIFICATION_CONFIG), - XContentType.JSON)) { - return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); - } - } - - private TrainedModelConfig buildRegressionModel() throws IOException { - try (XContentParser parser = XContentHelper.createParser(xContentRegistry(), - DeprecationHandler.THROW_UNSUPPORTED_OPERATION, - new BytesArray(REGRESSION_CONFIG), - XContentType.JSON)) { - return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build(); - } - } - @Override protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); @@ -537,16 +505,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { private static final String CLASSIFICATION_CONFIG = "" + "{\n" + - " \"model_id\": \"test_classification\",\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + - " \"version\": \"7.6.0\",\n" + - " \"definition\": " + CLASSIFICATION_DEFINITION + ","+ - " \"license_level\": \"platinum\",\n" + - " \"created_by\": \"es_test\",\n" + - " \"estimated_heap_memory_usage_bytes\": 0," + - " \"estimated_operations\": 0," + - " \"created_time\": 0\n" + + " \"definition\": " + CLASSIFICATION_DEFINITION + "}"; private static final String CLASSIFICATION_PIPELINE = "{" +