We should prefer secondary auth headers when calling _explain
This commit is contained in:
parent
ca68298e89
commit
a72d7cc76a
|
@ -0,0 +1,196 @@
|
||||||
|
/*
|
||||||
|
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||||
|
* or more contributor license agreements. Licensed under the Elastic License;
|
||||||
|
* you may not use this file except in compliance with the Elastic License.
|
||||||
|
*/
|
||||||
|
package org.elasticsearch.xpack.ml.integration;
|
||||||
|
|
||||||
|
import org.apache.http.util.EntityUtils;
|
||||||
|
import org.elasticsearch.client.Request;
|
||||||
|
import org.elasticsearch.client.RequestOptions;
|
||||||
|
import org.elasticsearch.client.ResponseException;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
|
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||||
|
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||||
|
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||||
|
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static org.hamcrest.Matchers.containsString;
|
||||||
|
import static org.hamcrest.Matchers.equalTo;
|
||||||
|
import static org.hamcrest.Matchers.not;
|
||||||
|
|
||||||
|
public class ExplainDataFrameAnalyticsRestIT extends ESRestTestCase {
|
||||||
|
|
||||||
|
private static String basicAuth(String user) {
|
||||||
|
return UsernamePasswordToken.basicAuthHeaderValue(user, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final String SUPER_USER = "x_pack_rest_user";
|
||||||
|
private static final String ML_ADMIN = "ml_admin";
|
||||||
|
private static final String BASIC_AUTH_VALUE_SUPER_USER = basicAuth(SUPER_USER);
|
||||||
|
private static final String AUTH_KEY = "Authorization";
|
||||||
|
private static final String SECONDARY_AUTH_KEY = "es-secondary-authorization";
|
||||||
|
|
||||||
|
private static RequestOptions.Builder addAuthHeader(RequestOptions.Builder builder, String user) {
|
||||||
|
builder.addHeader(AUTH_KEY, basicAuth(user));
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static RequestOptions.Builder addSecondaryAuthHeader(RequestOptions.Builder builder, String user) {
|
||||||
|
builder.addHeader(SECONDARY_AUTH_KEY, basicAuth(user));
|
||||||
|
return builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected Settings restClientSettings() {
|
||||||
|
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setupUser(String user, List<String> roles) throws IOException {
|
||||||
|
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());
|
||||||
|
|
||||||
|
Request request = new Request("PUT", "/_security/user/" + user);
|
||||||
|
request.setJsonEntity("{"
|
||||||
|
+ " \"password\" : \"" + password + "\","
|
||||||
|
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
|
||||||
|
+ "}");
|
||||||
|
client().performRequest(request);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setUpData() throws Exception {
|
||||||
|
// This user has admin rights on machine learning, but (importantly for the tests) no rights
|
||||||
|
// on any of the data indexes
|
||||||
|
setupUser(ML_ADMIN, Collections.singletonList("machine_learning_admin"));
|
||||||
|
addAirlineData();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addAirlineData() throws IOException {
|
||||||
|
StringBuilder bulk = new StringBuilder();
|
||||||
|
|
||||||
|
// Create index with source = enabled, doc_values = enabled, stored = false + multi-field
|
||||||
|
Request createAirlineDataRequest = new Request("PUT", "/airline-data");
|
||||||
|
createAirlineDataRequest.setJsonEntity("{"
|
||||||
|
+ " \"mappings\": {"
|
||||||
|
+ " \"properties\": {"
|
||||||
|
+ " \"time stamp\": { \"type\":\"date\"}," // space in 'time stamp' is intentional
|
||||||
|
+ " \"airline\": {"
|
||||||
|
+ " \"type\":\"keyword\""
|
||||||
|
+ " },"
|
||||||
|
+ " \"responsetime\": { \"type\":\"float\"}"
|
||||||
|
+ " }"
|
||||||
|
+ " }"
|
||||||
|
+ "}");
|
||||||
|
client().performRequest(createAirlineDataRequest);
|
||||||
|
|
||||||
|
bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 1}}\n");
|
||||||
|
bulk.append("{\"time stamp\":\"2016-06-01T00:00:00Z\",\"airline\":\"AAA\",\"responsetime\":135.22}\n");
|
||||||
|
bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 2}}\n");
|
||||||
|
bulk.append("{\"time stamp\":\"2016-06-01T01:59:00Z\",\"airline\":\"AAA\",\"responsetime\":541.76}\n");
|
||||||
|
|
||||||
|
bulkIndex(bulk.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExplain_GivenSecondaryHeadersAndConfig() throws IOException {
|
||||||
|
String config = "{\n" +
|
||||||
|
" \"source\": {\n" +
|
||||||
|
" \"index\": \"airline-data\"\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"analysis\": {\n" +
|
||||||
|
" \"regression\": {\n" +
|
||||||
|
" \"dependent_variable\": \"responsetime\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}";
|
||||||
|
|
||||||
|
|
||||||
|
{ // Request with secondary headers without perms
|
||||||
|
Request explain = explainRequestViaConfig(config);
|
||||||
|
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||||
|
addAuthHeader(options, SUPER_USER);
|
||||||
|
addSecondaryAuthHeader(options, ML_ADMIN);
|
||||||
|
explain.setOptions(options);
|
||||||
|
// Should throw
|
||||||
|
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
|
||||||
|
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
|
||||||
|
}
|
||||||
|
{ // request with secondary headers with perms
|
||||||
|
Request explain = explainRequestViaConfig(config);
|
||||||
|
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||||
|
addAuthHeader(options, ML_ADMIN);
|
||||||
|
addSecondaryAuthHeader(options, SUPER_USER);
|
||||||
|
explain.setOptions(options);
|
||||||
|
// Should not throw
|
||||||
|
client().performRequest(explain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testExplain_GivenSecondaryHeadersAndPreviouslyStoredConfig() throws IOException {
|
||||||
|
String config = "{\n" +
|
||||||
|
" \"source\": {\n" +
|
||||||
|
" \"index\": \"airline-data\"\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"dest\": {\n" +
|
||||||
|
" \"index\": \"response_prediction\"\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"analysis\":\n" +
|
||||||
|
" {\n" +
|
||||||
|
" \"regression\": {\n" +
|
||||||
|
" \"dependent_variable\": \"responsetime\"\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }\n" +
|
||||||
|
"}";
|
||||||
|
|
||||||
|
String configId = "explain_test";
|
||||||
|
|
||||||
|
Request storeConfig = new Request("PUT", "_ml/data_frame/analytics/" + configId);
|
||||||
|
storeConfig.setJsonEntity(config);
|
||||||
|
client().performRequest(storeConfig);
|
||||||
|
|
||||||
|
{ // Request with secondary headers without perms
|
||||||
|
Request explain = explainRequestConfigId(configId);
|
||||||
|
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||||
|
addAuthHeader(options, SUPER_USER);
|
||||||
|
addSecondaryAuthHeader(options, ML_ADMIN);
|
||||||
|
explain.setOptions(options);
|
||||||
|
// Should throw
|
||||||
|
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
|
||||||
|
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
|
||||||
|
}
|
||||||
|
{ // request with secondary headers with perms
|
||||||
|
Request explain = explainRequestConfigId(configId);
|
||||||
|
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||||
|
addAuthHeader(options, ML_ADMIN);
|
||||||
|
addSecondaryAuthHeader(options, SUPER_USER);
|
||||||
|
explain.setOptions(options);
|
||||||
|
// Should not throw
|
||||||
|
client().performRequest(explain);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Request explainRequestViaConfig(String config) {
|
||||||
|
Request request = new Request("POST", "_ml/data_frame/analytics/_explain");
|
||||||
|
request.setJsonEntity(config);
|
||||||
|
return request;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Request explainRequestConfigId(String id) {
|
||||||
|
return new Request("POST", "_ml/data_frame/analytics/" + id + "/_explain");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void bulkIndex(String bulk) throws IOException {
|
||||||
|
Request bulkRequest = new Request("POST", "/_bulk");
|
||||||
|
bulkRequest.setJsonEntity(bulk);
|
||||||
|
bulkRequest.addParameter("refresh", "true");
|
||||||
|
bulkRequest.addParameter("pretty", null);
|
||||||
|
String bulkResponse = EntityUtils.toString(client().performRequest(bulkRequest).getEntity());
|
||||||
|
assertThat(bulkResponse, not(containsString("\"errors\": false")));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -16,16 +16,21 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
||||||
import org.elasticsearch.cluster.service.ClusterService;
|
import org.elasticsearch.cluster.service.ClusterService;
|
||||||
import org.elasticsearch.common.collect.Tuple;
|
import org.elasticsearch.common.collect.Tuple;
|
||||||
import org.elasticsearch.common.inject.Inject;
|
import org.elasticsearch.common.inject.Inject;
|
||||||
|
import org.elasticsearch.common.settings.Settings;
|
||||||
import org.elasticsearch.license.LicenseUtils;
|
import org.elasticsearch.license.LicenseUtils;
|
||||||
import org.elasticsearch.license.XPackLicenseState;
|
import org.elasticsearch.license.XPackLicenseState;
|
||||||
import org.elasticsearch.tasks.Task;
|
import org.elasticsearch.tasks.Task;
|
||||||
|
import org.elasticsearch.threadpool.ThreadPool;
|
||||||
import org.elasticsearch.transport.TransportService;
|
import org.elasticsearch.transport.TransportService;
|
||||||
import org.elasticsearch.xpack.core.XPackField;
|
import org.elasticsearch.xpack.core.XPackField;
|
||||||
|
import org.elasticsearch.xpack.core.XPackSettings;
|
||||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||||
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
||||||
|
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
|
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
|
||||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
|
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
|
||||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||||
|
import org.elasticsearch.xpack.core.security.SecurityContext;
|
||||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||||
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
|
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
|
||||||
|
@ -37,6 +42,9 @@ import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
|
||||||
|
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provides explanations on aspects of the given data frame analytics spec like memory estimation, field selection, etc.
|
* Provides explanations on aspects of the given data frame analytics spec like memory estimation, field selection, etc.
|
||||||
* Redirects to a different node if the current node is *not* an ML node.
|
* Redirects to a different node if the current node is *not* an ML node.
|
||||||
|
@ -49,6 +57,8 @@ public class TransportExplainDataFrameAnalyticsAction
|
||||||
private final ClusterService clusterService;
|
private final ClusterService clusterService;
|
||||||
private final NodeClient client;
|
private final NodeClient client;
|
||||||
private final MemoryUsageEstimationProcessManager processManager;
|
private final MemoryUsageEstimationProcessManager processManager;
|
||||||
|
private final SecurityContext securityContext;
|
||||||
|
private final ThreadPool threadPool;
|
||||||
|
|
||||||
@Inject
|
@Inject
|
||||||
public TransportExplainDataFrameAnalyticsAction(TransportService transportService,
|
public TransportExplainDataFrameAnalyticsAction(TransportService transportService,
|
||||||
|
@ -56,13 +66,19 @@ public class TransportExplainDataFrameAnalyticsAction
|
||||||
ClusterService clusterService,
|
ClusterService clusterService,
|
||||||
NodeClient client,
|
NodeClient client,
|
||||||
XPackLicenseState licenseState,
|
XPackLicenseState licenseState,
|
||||||
MemoryUsageEstimationProcessManager processManager) {
|
MemoryUsageEstimationProcessManager processManager,
|
||||||
|
Settings settings,
|
||||||
|
ThreadPool threadPool) {
|
||||||
super(ExplainDataFrameAnalyticsAction.NAME, transportService, actionFilters, PutDataFrameAnalyticsAction.Request::new);
|
super(ExplainDataFrameAnalyticsAction.NAME, transportService, actionFilters, PutDataFrameAnalyticsAction.Request::new);
|
||||||
this.transportService = transportService;
|
this.transportService = transportService;
|
||||||
this.clusterService = Objects.requireNonNull(clusterService);
|
this.clusterService = Objects.requireNonNull(clusterService);
|
||||||
this.client = Objects.requireNonNull(client);
|
this.client = Objects.requireNonNull(client);
|
||||||
this.licenseState = licenseState;
|
this.licenseState = licenseState;
|
||||||
this.processManager = Objects.requireNonNull(processManager);
|
this.processManager = Objects.requireNonNull(processManager);
|
||||||
|
this.threadPool = threadPool;
|
||||||
|
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
|
||||||
|
new SecurityContext(settings, threadPool.getThreadContext()) :
|
||||||
|
null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -84,17 +100,38 @@ public class TransportExplainDataFrameAnalyticsAction
|
||||||
|
|
||||||
private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
|
private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
|
||||||
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
||||||
ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory =
|
|
||||||
new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId()));
|
final ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(
|
||||||
extractedFieldsDetectorFactory.createFromSource(
|
new ParentTaskAssigningClient(client, task.getParentTaskId())
|
||||||
request.getConfig(),
|
|
||||||
ActionListener.wrap(
|
|
||||||
extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, listener),
|
|
||||||
listener::onFailure)
|
|
||||||
);
|
);
|
||||||
|
if (licenseState.isSecurityEnabled()) {
|
||||||
|
useSecondaryAuthIfAvailable(this.securityContext, () -> {
|
||||||
|
// Set the auth headers (preferring the secondary headers) to the caller's.
|
||||||
|
// Regardless if the config was previously stored or not.
|
||||||
|
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder(request.getConfig())
|
||||||
|
.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()))
|
||||||
|
.build();
|
||||||
|
extractedFieldsDetectorFactory.createFromSource(
|
||||||
|
config,
|
||||||
|
ActionListener.wrap(
|
||||||
|
extractedFieldsDetector -> explain(task, config, extractedFieldsDetector, listener),
|
||||||
|
listener::onFailure
|
||||||
|
)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
extractedFieldsDetectorFactory.createFromSource(
|
||||||
|
request.getConfig(),
|
||||||
|
ActionListener.wrap(
|
||||||
|
extractedFieldsDetector -> explain(task, request.getConfig(), extractedFieldsDetector, listener),
|
||||||
|
listener::onFailure
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ExtractedFieldsDetector extractedFieldsDetector,
|
private void explain(Task task, DataFrameAnalyticsConfig config, ExtractedFieldsDetector extractedFieldsDetector,
|
||||||
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
||||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||||
|
|
||||||
|
@ -103,7 +140,7 @@ public class TransportExplainDataFrameAnalyticsAction
|
||||||
listener::onFailure
|
listener::onFailure
|
||||||
);
|
);
|
||||||
|
|
||||||
estimateMemoryUsage(task, request, fieldExtraction.v1(), memoryEstimationListener);
|
estimateMemoryUsage(task, config, fieldExtraction.v1(), memoryEstimationListener);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -112,15 +149,15 @@ public class TransportExplainDataFrameAnalyticsAction
|
||||||
* the ML node.
|
* the ML node.
|
||||||
*/
|
*/
|
||||||
private void estimateMemoryUsage(Task task,
|
private void estimateMemoryUsage(Task task,
|
||||||
PutDataFrameAnalyticsAction.Request request,
|
DataFrameAnalyticsConfig config,
|
||||||
ExtractedFields extractedFields,
|
ExtractedFields extractedFields,
|
||||||
ActionListener<MemoryEstimation> listener) {
|
ActionListener<MemoryEstimation> listener) {
|
||||||
final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
|
final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
|
||||||
DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
|
DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
|
||||||
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields);
|
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, config, extractedFields);
|
||||||
processManager.runJobAsync(
|
processManager.runJobAsync(
|
||||||
estimateMemoryTaskId,
|
estimateMemoryTaskId,
|
||||||
request.getConfig(),
|
config,
|
||||||
extractorFactory,
|
extractorFactory,
|
||||||
ActionListener.wrap(
|
ActionListener.wrap(
|
||||||
result -> listener.onResponse(
|
result -> listener.onResponse(
|
||||||
|
|
Loading…
Reference in New Issue