[ML][Inference] fixing ingest IT tests (#51267) (#51311)

Converts InferenceIngestIT into an `ESRestTestCase`.

closes #51201
This commit is contained in:
Benjamin Trent 2020-01-22 09:50:17 -05:00 committed by GitHub
parent b38fdf9f94
commit 2a73e849d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,141 +5,109 @@
*/ */
package org.elasticsearch.xpack.ml.integration; package org.elasticsearch.xpack.ml.integration;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.apache.http.util.EntityUtils;
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult; import org.elasticsearch.client.Request;
import org.elasticsearch.action.ingest.SimulatePipelineResponse; import org.elasticsearch.client.Response;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.ExternalTestCluster;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction; import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase { /**
* This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems,
* specifically when ensuring the accounting breaker has been reset.
* It has to do with `_simulate`, not really anything to do with the ML code.
*/
public class InferenceIngestIT extends ESRestTestCase {
private static final String BASIC_AUTH_VALUE_SUPER_USER =
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
@Before @Before
public void createBothModels() throws Exception { public void createBothModels() throws Exception {
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet(); Request request = new Request("PUT", "_ml/inference/test_classification");
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet(); request.setJsonEntity(CLASSIFICATION_CONFIG);
client().performRequest(request);
request = new Request("PUT", "_ml/inference/test_regression");
request.setJsonEntity(REGRESSION_CONFIG);
client().performRequest(request);
}
@Override
protected Settings restClientSettings() {
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
} }
@After @After
public void deleteBothModels() { public void cleanUpData() throws Exception {
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet(); new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet(); ESRestTestCase.waitForPendingTasks(adminClient());
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
} }
public void testPipelineCreationAndDeletion() throws Exception { public void testPipelineCreationAndDeletion() throws Exception {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
XContentType.JSON).get().isAcknowledged(), is(true)); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
.setSource(new HashMap<String, Object>(){{ client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
put("col1", randomFrom("female", "male")); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_classification_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(new HashMap<String, Object>(){{
put("col1", randomFrom("female", "male"));
put("col2", randomFrom("S", "M", "L", "XL"));
put("col3", randomFrom("true", "false", "none", "other"));
put("col4", randomIntBetween(0, 10));
}})
.setPipeline("simple_regression_pipeline")
.get();
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
is(true));
} }
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline", client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)), client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
XContentType.JSON).get().isAcknowledged(), is(true));
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
XContentType.JSON).get().isAcknowledged(), is(true));
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME) client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
.setSource(generateSourceDoc()) client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
.setPipeline("simple_classification_pipeline")
.get();
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
.setSource(generateSourceDoc())
.setPipeline("simple_regression_pipeline")
.get();
} }
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(), client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
is(true)); client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(), client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
is(true));
client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();
assertThat(client().search(new SearchRequest().indices("index_for_inference_test") Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
.source(new SearchSourceBuilder() QueryBuilders.boolQuery()
.size(0) .filter(
.trackTotalHits(true) QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
.query(QueryBuilders.boolQuery() assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
.filter(
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))))).get().getHits().getTotalHits().value,
equalTo(20L));
assertThat(client().search(new SearchRequest().indices("index_for_inference_test") searchResponse = client().performRequest(searchRequest("index_for_inference_test",
.source(new SearchSourceBuilder() QueryBuilders.boolQuery()
.size(0) .filter(
.trackTotalHits(true) QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
.query(QueryBuilders.boolQuery()
.filter(
QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))))
.get()
.getHits()
.getTotalHits()
.value,
equalTo(20L));
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
} }
public void testSimulate() { public void testSimulate() throws IOException {
String source = "{\n" + String source = "{\n" +
" \"pipeline\": {\n" + " \"pipeline\": {\n" +
" \"processors\": [\n" + " \"processors\": [\n" +
@ -181,15 +149,10 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" + " }}]\n" +
"}"; "}";
SimulatePipelineResponse response = client().admin().cluster() Response response = client().performRequest(simulateRequest(source));
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), String responseString = EntityUtils.toString(response.getEntity());
XContentType.JSON).get(); assertThat(responseString, containsString("\"predicted_value\":\"second\""));
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0); assertThat(responseString, containsString("\"predicted_value\":1.0"));
assertThat(baseResult.getIngestDocument().getFieldValue("ml.regression.predicted_value", Double.class), equalTo(1.0));
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.predicted_value", String.class),
equalTo("second"));
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.result_class_prob", List.class).size(),
equalTo(2));
String sourceWithMissingModel = "{\n" + String sourceWithMissingModel = "{\n" +
" \"pipeline\": {\n" + " \"pipeline\": {\n" +
@ -217,15 +180,13 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" + " }}]\n" +
"}"; "}";
response = client().admin().cluster() response = client().performRequest(simulateRequest(sourceWithMissingModel));
.prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)), responseString = EntityUtils.toString(response.getEntity());
XContentType.JSON).get();
assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(), assertThat(responseString, containsString("Could not find trained model [test_classification_missing]"));
containsString("Could not find trained model [test_classification_missing]"));
} }
public void testSimulateLangIdent() { public void testSimulateLangIdent() throws IOException {
String source = "{\n" + String source = "{\n" +
" \"pipeline\": {\n" + " \"pipeline\": {\n" +
" \"processors\": [\n" + " \"processors\": [\n" +
@ -244,11 +205,43 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }}]\n" + " }}]\n" +
"}"; "}";
SimulatePipelineResponse response = client().admin().cluster() Response response = client().performRequest(simulateRequest(source));
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)), assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
XContentType.JSON).get(); }
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en")); private static Request simulateRequest(String jsonEntity) {
Request request = new Request("POST", "_ingest/pipeline/_simulate");
request.setJsonEntity(jsonEntity);
return request;
}
private static Request indexRequest(String index, String pipeline, Map<String, Object> doc) throws IOException {
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(doc)) {
return indexRequest(index,
pipeline,
XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON));
}
}
private static Request indexRequest(String index, String pipeline, String doc) {
Request request = new Request("POST", index + "/_doc?pipeline=" + pipeline);
request.setJsonEntity(doc);
return request;
}
private static Request putPipeline(String pipelineId, String pipelineDefinition) {
Request request = new Request("PUT", "_ingest/pipeline/" + pipelineId);
request.setJsonEntity(pipelineDefinition);
return request;
}
private static Request searchRequest(String index, QueryBuilder queryBuilder) throws IOException {
BytesReference reference = XContentHelper.toXContent(queryBuilder, XContentType.JSON, false);
String queryJson = XContentHelper.convertToJson(reference, false, XContentType.JSON);
String json = "{\"query\": " + queryJson + "}";
Request request = new Request("GET", index + "/_search?track_total_hits=true");
request.setJsonEntity(json);
return request;
} }
private Map<String, Object> generateSourceDoc() { private Map<String, Object> generateSourceDoc() {
@ -380,16 +373,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
"}"; "}";
private static final String REGRESSION_CONFIG = "{" + private static final String REGRESSION_CONFIG = "{" +
" \"model_id\": \"test_regression\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for regression\",\n" + " \"description\": \"test model for regression\",\n" +
" \"version\": \"7.6.0\",\n" + " \"definition\": " + REGRESSION_DEFINITION +
" \"definition\": " + REGRESSION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"ml_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0" +
"}"; "}";
private static final String CLASSIFICATION_DEFINITION = "{" + private static final String CLASSIFICATION_DEFINITION = "{" +
@ -512,24 +498,6 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
" }\n" + " }\n" +
"}"; "}";
private TrainedModelConfig buildClassificationModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(CLASSIFICATION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
private TrainedModelConfig buildRegressionModel() throws IOException {
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(REGRESSION_CONFIG),
XContentType.JSON)) {
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
}
}
@Override @Override
protected NamedXContentRegistry xContentRegistry() { protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
@ -537,16 +505,9 @@ public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
private static final String CLASSIFICATION_CONFIG = "" + private static final String CLASSIFICATION_CONFIG = "" +
"{\n" + "{\n" +
" \"model_id\": \"test_classification\",\n" +
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
" \"description\": \"test model for classification\",\n" + " \"description\": \"test model for classification\",\n" +
" \"version\": \"7.6.0\",\n" + " \"definition\": " + CLASSIFICATION_DEFINITION +
" \"definition\": " + CLASSIFICATION_DEFINITION + ","+
" \"license_level\": \"platinum\",\n" +
" \"created_by\": \"es_test\",\n" +
" \"estimated_heap_memory_usage_bytes\": 0," +
" \"estimated_operations\": 0," +
" \"created_time\": 0\n" +
"}"; "}";
private static final String CLASSIFICATION_PIPELINE = "{" + private static final String CLASSIFICATION_PIPELINE = "{" +