From a72d7cc76a98546acc0426f02827d303e2a6b6a8 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 6 Oct 2020 09:15:29 -0400 Subject: [PATCH] [ML] prefer secondary auth headers on data frame analytics _explain (#63281) (#63323) We should prefer secondary auth headers when calling _explain --- .../ExplainDataFrameAnalyticsRestIT.java | 196 ++++++++++++++++++ ...nsportExplainDataFrameAnalyticsAction.java | 63 ++++-- 2 files changed, 246 insertions(+), 13 deletions(-) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsRestIT.java diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsRestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsRestIT.java new file mode 100644 index 00000000000..d7a451cf96d --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsRestIT.java @@ -0,0 +1,196 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.apache.http.util.EntityUtils; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class ExplainDataFrameAnalyticsRestIT extends ESRestTestCase { + + private static String basicAuth(String user) { + return UsernamePasswordToken.basicAuthHeaderValue(user, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING); + } + + private static final String SUPER_USER = "x_pack_rest_user"; + private static final String ML_ADMIN = "ml_admin"; + private static final String BASIC_AUTH_VALUE_SUPER_USER = basicAuth(SUPER_USER); + private static final String AUTH_KEY = "Authorization"; + private static final String SECONDARY_AUTH_KEY = "es-secondary-authorization"; + + private static RequestOptions.Builder addAuthHeader(RequestOptions.Builder builder, String user) { + builder.addHeader(AUTH_KEY, basicAuth(user)); + return builder; + } + + private static RequestOptions.Builder addSecondaryAuthHeader(RequestOptions.Builder builder, String user) { + builder.addHeader(SECONDARY_AUTH_KEY, basicAuth(user)); + return builder; + } + + @Override + protected Settings restClientSettings() { + return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build(); + } + + private void setupUser(String 
user, List roles) throws IOException { + String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars()); + + Request request = new Request("PUT", "/_security/user/" + user); + request.setJsonEntity("{" + + " \"password\" : \"" + password + "\"," + + " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]" + + "}"); + client().performRequest(request); + } + + @Before + public void setUpData() throws Exception { + // This user has admin rights on machine learning, but (importantly for the tests) no rights + // on any of the data indexes + setupUser(ML_ADMIN, Collections.singletonList("machine_learning_admin")); + addAirlineData(); + } + + private void addAirlineData() throws IOException { + StringBuilder bulk = new StringBuilder(); + + // Create index with source = enabled, doc_values = enabled, stored = false + multi-field + Request createAirlineDataRequest = new Request("PUT", "/airline-data"); + createAirlineDataRequest.setJsonEntity("{" + + " \"mappings\": {" + + " \"properties\": {" + + " \"time stamp\": { \"type\":\"date\"}," // space in 'time stamp' is intentional + + " \"airline\": {" + + " \"type\":\"keyword\"" + + " }," + + " \"responsetime\": { \"type\":\"float\"}" + + " }" + + " }" + + "}"); + client().performRequest(createAirlineDataRequest); + + bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 1}}\n"); + bulk.append("{\"time stamp\":\"2016-06-01T00:00:00Z\",\"airline\":\"AAA\",\"responsetime\":135.22}\n"); + bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 2}}\n"); + bulk.append("{\"time stamp\":\"2016-06-01T01:59:00Z\",\"airline\":\"AAA\",\"responsetime\":541.76}\n"); + + bulkIndex(bulk.toString()); + } + + public void testExplain_GivenSecondaryHeadersAndConfig() throws IOException { + String config = "{\n" + + " \"source\": {\n" + + " \"index\": \"airline-data\"\n" + + " },\n" + + " \"analysis\": {\n" + + " \"regression\": {\n" 
+ + " \"dependent_variable\": \"responsetime\"\n" + + " }\n" + + " }\n" + + "}"; + + + { // Request with secondary headers without perms + Request explain = explainRequestViaConfig(config); + RequestOptions.Builder options = explain.getOptions().toBuilder(); + addAuthHeader(options, SUPER_USER); + addSecondaryAuthHeader(options, ML_ADMIN); + explain.setOptions(options); + // Should throw + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain)); + assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403)); + } + { // request with secondary headers with perms + Request explain = explainRequestViaConfig(config); + RequestOptions.Builder options = explain.getOptions().toBuilder(); + addAuthHeader(options, ML_ADMIN); + addSecondaryAuthHeader(options, SUPER_USER); + explain.setOptions(options); + // Should not throw + client().performRequest(explain); + } + } + + public void testExplain_GivenSecondaryHeadersAndPreviouslyStoredConfig() throws IOException { + String config = "{\n" + + " \"source\": {\n" + + " \"index\": \"airline-data\"\n" + + " },\n" + + " \"dest\": {\n" + + " \"index\": \"response_prediction\"\n" + + " },\n" + + " \"analysis\":\n" + + " {\n" + + " \"regression\": {\n" + + " \"dependent_variable\": \"responsetime\"\n" + + " }\n" + + " }\n" + + "}"; + + String configId = "explain_test"; + + Request storeConfig = new Request("PUT", "_ml/data_frame/analytics/" + configId); + storeConfig.setJsonEntity(config); + client().performRequest(storeConfig); + + { // Request with secondary headers without perms + Request explain = explainRequestConfigId(configId); + RequestOptions.Builder options = explain.getOptions().toBuilder(); + addAuthHeader(options, SUPER_USER); + addSecondaryAuthHeader(options, ML_ADMIN); + explain.setOptions(options); + // Should throw + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain)); + 
+ assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403)); + } + { // request with secondary headers with perms + Request explain = explainRequestConfigId(configId); + RequestOptions.Builder options = explain.getOptions().toBuilder(); + addAuthHeader(options, ML_ADMIN); + addSecondaryAuthHeader(options, SUPER_USER); + explain.setOptions(options); + // Should not throw + client().performRequest(explain); + } + } + + private static Request explainRequestViaConfig(String config) { + Request request = new Request("POST", "_ml/data_frame/analytics/_explain"); + request.setJsonEntity(config); + return request; + } + + private static Request explainRequestConfigId(String id) { + return new Request("POST", "_ml/data_frame/analytics/" + id + "/_explain"); + } + + private void bulkIndex(String bulk) throws IOException { + Request bulkRequest = new Request("POST", "/_bulk"); + bulkRequest.setJsonEntity(bulk); + bulkRequest.addParameter("refresh", "true"); + bulkRequest.addParameter("pretty", null); + String bulkResponse = EntityUtils.toString(client().performRequest(bulkRequest).getEntity()); + assertThat(bulkResponse.replaceAll("\\s", ""), not(containsString("\"errors\":true"))); // ignore pretty-print spacing + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java index 44cad0119b8..c0ccd42d38a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportExplainDataFrameAnalyticsAction.java @@ -16,16 +16,21 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; import 
org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection; import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector; @@ -37,6 +42,9 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; +import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable; + /** * Provides explanations on aspects of the given data frame analytics spec like memory estimation, field selection, etc. * Redirects to a different node if the current node is *not* an ML node. 
@@ -49,6 +57,8 @@ public class TransportExplainDataFrameAnalyticsAction private final ClusterService clusterService; private final NodeClient client; private final MemoryUsageEstimationProcessManager processManager; + private final SecurityContext securityContext; + private final ThreadPool threadPool; @Inject public TransportExplainDataFrameAnalyticsAction(TransportService transportService, @@ -56,13 +66,19 @@ public class TransportExplainDataFrameAnalyticsAction ClusterService clusterService, NodeClient client, XPackLicenseState licenseState, - MemoryUsageEstimationProcessManager processManager) { + MemoryUsageEstimationProcessManager processManager, + Settings settings, + ThreadPool threadPool) { super(ExplainDataFrameAnalyticsAction.NAME, transportService, actionFilters, PutDataFrameAnalyticsAction.Request::new); this.transportService = transportService; this.clusterService = Objects.requireNonNull(clusterService); this.client = Objects.requireNonNull(client); this.licenseState = licenseState; this.processManager = Objects.requireNonNull(processManager); + this.threadPool = threadPool; + this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ? 
+ new SecurityContext(settings, threadPool.getThreadContext()) : + null; } @Override @@ -84,17 +100,38 @@ public class TransportExplainDataFrameAnalyticsAction private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ActionListener listener) { - ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = - new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId())); - extractedFieldsDetectorFactory.createFromSource( - request.getConfig(), - ActionListener.wrap( - extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, listener), - listener::onFailure) + + final ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory( + new ParentTaskAssigningClient(client, task.getParentTaskId()) ); + if (licenseState.isSecurityEnabled()) { + useSecondaryAuthIfAvailable(this.securityContext, () -> { + // Set the auth headers (preferring the secondary headers) to the caller's. + // Regardless if the config was previously stored or not. 
+ DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder(request.getConfig()) + .setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders())) + .build(); + extractedFieldsDetectorFactory.createFromSource( + config, + ActionListener.wrap( + extractedFieldsDetector -> explain(task, config, extractedFieldsDetector, listener), + listener::onFailure + ) + ); + }); + } else { + extractedFieldsDetectorFactory.createFromSource( + request.getConfig(), + ActionListener.wrap( + extractedFieldsDetector -> explain(task, request.getConfig(), extractedFieldsDetector, listener), + listener::onFailure + ) + ); + } + } - private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ExtractedFieldsDetector extractedFieldsDetector, + private void explain(Task task, DataFrameAnalyticsConfig config, ExtractedFieldsDetector extractedFieldsDetector, ActionListener listener) { Tuple> fieldExtraction = extractedFieldsDetector.detect(); @@ -103,7 +140,7 @@ public class TransportExplainDataFrameAnalyticsAction listener::onFailure ); - estimateMemoryUsage(task, request, fieldExtraction.v1(), memoryEstimationListener); + estimateMemoryUsage(task, config, fieldExtraction.v1(), memoryEstimationListener); } /** @@ -112,15 +149,15 @@ public class TransportExplainDataFrameAnalyticsAction * the ML node. 
*/ private void estimateMemoryUsage(Task task, - PutDataFrameAnalyticsAction.Request request, + DataFrameAnalyticsConfig config, ExtractedFields extractedFields, ActionListener listener) { final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId(); DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices( - new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields); + new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, config, extractedFields); processManager.runJobAsync( estimateMemoryTaskId, - request.getConfig(), + config, extractorFactory, ActionListener.wrap( result -> listener.onResponse(