We should prefer secondary auth headers when calling _explain
This commit is contained in:
parent
ca68298e89
commit
a72d7cc76a
|
@ -0,0 +1,196 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
package org.elasticsearch.xpack.ml.integration;
|
||||
|
||||
import org.apache.http.util.EntityUtils;
|
||||
import org.elasticsearch.client.Request;
|
||||
import org.elasticsearch.client.RequestOptions;
|
||||
import org.elasticsearch.client.ResponseException;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.test.SecuritySettingsSourceField;
|
||||
import org.elasticsearch.test.rest.ESRestTestCase;
|
||||
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static org.hamcrest.Matchers.containsString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.not;
|
||||
|
||||
public class ExplainDataFrameAnalyticsRestIT extends ESRestTestCase {
|
||||
|
||||
private static String basicAuth(String user) {
|
||||
return UsernamePasswordToken.basicAuthHeaderValue(user, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
|
||||
}
|
||||
|
||||
private static final String SUPER_USER = "x_pack_rest_user";
|
||||
private static final String ML_ADMIN = "ml_admin";
|
||||
private static final String BASIC_AUTH_VALUE_SUPER_USER = basicAuth(SUPER_USER);
|
||||
private static final String AUTH_KEY = "Authorization";
|
||||
private static final String SECONDARY_AUTH_KEY = "es-secondary-authorization";
|
||||
|
||||
private static RequestOptions.Builder addAuthHeader(RequestOptions.Builder builder, String user) {
|
||||
builder.addHeader(AUTH_KEY, basicAuth(user));
|
||||
return builder;
|
||||
}
|
||||
|
||||
private static RequestOptions.Builder addSecondaryAuthHeader(RequestOptions.Builder builder, String user) {
|
||||
builder.addHeader(SECONDARY_AUTH_KEY, basicAuth(user));
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Settings restClientSettings() {
|
||||
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
|
||||
}
|
||||
|
||||
private void setupUser(String user, List<String> roles) throws IOException {
|
||||
String password = new String(SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING.getChars());
|
||||
|
||||
Request request = new Request("PUT", "/_security/user/" + user);
|
||||
request.setJsonEntity("{"
|
||||
+ " \"password\" : \"" + password + "\","
|
||||
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
|
||||
+ "}");
|
||||
client().performRequest(request);
|
||||
}
|
||||
|
||||
@Before
|
||||
public void setUpData() throws Exception {
|
||||
// This user has admin rights on machine learning, but (importantly for the tests) no rights
|
||||
// on any of the data indexes
|
||||
setupUser(ML_ADMIN, Collections.singletonList("machine_learning_admin"));
|
||||
addAirlineData();
|
||||
}
|
||||
|
||||
private void addAirlineData() throws IOException {
|
||||
StringBuilder bulk = new StringBuilder();
|
||||
|
||||
// Create index with source = enabled, doc_values = enabled, stored = false + multi-field
|
||||
Request createAirlineDataRequest = new Request("PUT", "/airline-data");
|
||||
createAirlineDataRequest.setJsonEntity("{"
|
||||
+ " \"mappings\": {"
|
||||
+ " \"properties\": {"
|
||||
+ " \"time stamp\": { \"type\":\"date\"}," // space in 'time stamp' is intentional
|
||||
+ " \"airline\": {"
|
||||
+ " \"type\":\"keyword\""
|
||||
+ " },"
|
||||
+ " \"responsetime\": { \"type\":\"float\"}"
|
||||
+ " }"
|
||||
+ " }"
|
||||
+ "}");
|
||||
client().performRequest(createAirlineDataRequest);
|
||||
|
||||
bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 1}}\n");
|
||||
bulk.append("{\"time stamp\":\"2016-06-01T00:00:00Z\",\"airline\":\"AAA\",\"responsetime\":135.22}\n");
|
||||
bulk.append("{\"index\": {\"_index\": \"airline-data\", \"_id\": 2}}\n");
|
||||
bulk.append("{\"time stamp\":\"2016-06-01T01:59:00Z\",\"airline\":\"AAA\",\"responsetime\":541.76}\n");
|
||||
|
||||
bulkIndex(bulk.toString());
|
||||
}
|
||||
|
||||
public void testExplain_GivenSecondaryHeadersAndConfig() throws IOException {
|
||||
String config = "{\n" +
|
||||
" \"source\": {\n" +
|
||||
" \"index\": \"airline-data\"\n" +
|
||||
" },\n" +
|
||||
" \"analysis\": {\n" +
|
||||
" \"regression\": {\n" +
|
||||
" \"dependent_variable\": \"responsetime\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
|
||||
|
||||
{ // Request with secondary headers without perms
|
||||
Request explain = explainRequestViaConfig(config);
|
||||
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||
addAuthHeader(options, SUPER_USER);
|
||||
addSecondaryAuthHeader(options, ML_ADMIN);
|
||||
explain.setOptions(options);
|
||||
// Should throw
|
||||
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
|
||||
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
|
||||
}
|
||||
{ // request with secondary headers with perms
|
||||
Request explain = explainRequestViaConfig(config);
|
||||
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||
addAuthHeader(options, ML_ADMIN);
|
||||
addSecondaryAuthHeader(options, SUPER_USER);
|
||||
explain.setOptions(options);
|
||||
// Should not throw
|
||||
client().performRequest(explain);
|
||||
}
|
||||
}
|
||||
|
||||
public void testExplain_GivenSecondaryHeadersAndPreviouslyStoredConfig() throws IOException {
|
||||
String config = "{\n" +
|
||||
" \"source\": {\n" +
|
||||
" \"index\": \"airline-data\"\n" +
|
||||
" },\n" +
|
||||
" \"dest\": {\n" +
|
||||
" \"index\": \"response_prediction\"\n" +
|
||||
" },\n" +
|
||||
" \"analysis\":\n" +
|
||||
" {\n" +
|
||||
" \"regression\": {\n" +
|
||||
" \"dependent_variable\": \"responsetime\"\n" +
|
||||
" }\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
|
||||
String configId = "explain_test";
|
||||
|
||||
Request storeConfig = new Request("PUT", "_ml/data_frame/analytics/" + configId);
|
||||
storeConfig.setJsonEntity(config);
|
||||
client().performRequest(storeConfig);
|
||||
|
||||
{ // Request with secondary headers without perms
|
||||
Request explain = explainRequestConfigId(configId);
|
||||
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||
addAuthHeader(options, SUPER_USER);
|
||||
addSecondaryAuthHeader(options, ML_ADMIN);
|
||||
explain.setOptions(options);
|
||||
// Should throw
|
||||
ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(explain));
|
||||
assertThat(ex.getResponse().getStatusLine().getStatusCode(), equalTo(403));
|
||||
}
|
||||
{ // request with secondary headers with perms
|
||||
Request explain = explainRequestConfigId(configId);
|
||||
RequestOptions.Builder options = explain.getOptions().toBuilder();
|
||||
addAuthHeader(options, ML_ADMIN);
|
||||
addSecondaryAuthHeader(options, SUPER_USER);
|
||||
explain.setOptions(options);
|
||||
// Should not throw
|
||||
client().performRequest(explain);
|
||||
}
|
||||
}
|
||||
|
||||
private static Request explainRequestViaConfig(String config) {
|
||||
Request request = new Request("POST", "_ml/data_frame/analytics/_explain");
|
||||
request.setJsonEntity(config);
|
||||
return request;
|
||||
}
|
||||
|
||||
private static Request explainRequestConfigId(String id) {
|
||||
return new Request("POST", "_ml/data_frame/analytics/" + id + "/_explain");
|
||||
}
|
||||
|
||||
private void bulkIndex(String bulk) throws IOException {
|
||||
Request bulkRequest = new Request("POST", "/_bulk");
|
||||
bulkRequest.setJsonEntity(bulk);
|
||||
bulkRequest.addParameter("refresh", "true");
|
||||
bulkRequest.addParameter("pretty", null);
|
||||
String bulkResponse = EntityUtils.toString(client().performRequest(bulkRequest).getEntity());
|
||||
assertThat(bulkResponse, not(containsString("\"errors\": false")));
|
||||
}
|
||||
|
||||
}
|
|
@ -16,16 +16,21 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
|
|||
import org.elasticsearch.cluster.service.ClusterService;
|
||||
import org.elasticsearch.common.collect.Tuple;
|
||||
import org.elasticsearch.common.inject.Inject;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.license.LicenseUtils;
|
||||
import org.elasticsearch.license.XPackLicenseState;
|
||||
import org.elasticsearch.tasks.Task;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.transport.TransportService;
|
||||
import org.elasticsearch.xpack.core.XPackField;
|
||||
import org.elasticsearch.xpack.core.XPackSettings;
|
||||
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
|
||||
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
|
||||
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
|
||||
import org.elasticsearch.xpack.core.security.SecurityContext;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
|
||||
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
|
||||
|
@ -37,6 +42,9 @@ import java.util.List;
|
|||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
|
||||
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
|
||||
|
||||
/**
|
||||
* Provides explanations on aspects of the given data frame analytics spec like memory estimation, field selection, etc.
|
||||
* Redirects to a different node if the current node is *not* an ML node.
|
||||
|
@ -49,6 +57,8 @@ public class TransportExplainDataFrameAnalyticsAction
|
|||
private final ClusterService clusterService;
|
||||
private final NodeClient client;
|
||||
private final MemoryUsageEstimationProcessManager processManager;
|
||||
private final SecurityContext securityContext;
|
||||
private final ThreadPool threadPool;
|
||||
|
||||
@Inject
|
||||
public TransportExplainDataFrameAnalyticsAction(TransportService transportService,
|
||||
|
@ -56,13 +66,19 @@ public class TransportExplainDataFrameAnalyticsAction
|
|||
ClusterService clusterService,
|
||||
NodeClient client,
|
||||
XPackLicenseState licenseState,
|
||||
MemoryUsageEstimationProcessManager processManager) {
|
||||
MemoryUsageEstimationProcessManager processManager,
|
||||
Settings settings,
|
||||
ThreadPool threadPool) {
|
||||
super(ExplainDataFrameAnalyticsAction.NAME, transportService, actionFilters, PutDataFrameAnalyticsAction.Request::new);
|
||||
this.transportService = transportService;
|
||||
this.clusterService = Objects.requireNonNull(clusterService);
|
||||
this.client = Objects.requireNonNull(client);
|
||||
this.licenseState = licenseState;
|
||||
this.processManager = Objects.requireNonNull(processManager);
|
||||
this.threadPool = threadPool;
|
||||
this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ?
|
||||
new SecurityContext(settings, threadPool.getThreadContext()) :
|
||||
null;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -84,17 +100,38 @@ public class TransportExplainDataFrameAnalyticsAction
|
|||
|
||||
private void explain(Task task, PutDataFrameAnalyticsAction.Request request,
|
||||
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
||||
ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory =
|
||||
new ExtractedFieldsDetectorFactory(new ParentTaskAssigningClient(client, task.getParentTaskId()));
|
||||
|
||||
final ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(
|
||||
new ParentTaskAssigningClient(client, task.getParentTaskId())
|
||||
);
|
||||
if (licenseState.isSecurityEnabled()) {
|
||||
useSecondaryAuthIfAvailable(this.securityContext, () -> {
|
||||
// Set the auth headers (preferring the secondary headers) to the caller's.
|
||||
// Regardless if the config was previously stored or not.
|
||||
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder(request.getConfig())
|
||||
.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()))
|
||||
.build();
|
||||
extractedFieldsDetectorFactory.createFromSource(
|
||||
config,
|
||||
ActionListener.wrap(
|
||||
extractedFieldsDetector -> explain(task, config, extractedFieldsDetector, listener),
|
||||
listener::onFailure
|
||||
)
|
||||
);
|
||||
});
|
||||
} else {
|
||||
extractedFieldsDetectorFactory.createFromSource(
|
||||
request.getConfig(),
|
||||
ActionListener.wrap(
|
||||
extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, listener),
|
||||
listener::onFailure)
|
||||
extractedFieldsDetector -> explain(task, request.getConfig(), extractedFieldsDetector, listener),
|
||||
listener::onFailure
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private void explain(Task task, PutDataFrameAnalyticsAction.Request request, ExtractedFieldsDetector extractedFieldsDetector,
|
||||
}
|
||||
|
||||
private void explain(Task task, DataFrameAnalyticsConfig config, ExtractedFieldsDetector extractedFieldsDetector,
|
||||
ActionListener<ExplainDataFrameAnalyticsAction.Response> listener) {
|
||||
Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
|
||||
|
||||
|
@ -103,7 +140,7 @@ public class TransportExplainDataFrameAnalyticsAction
|
|||
listener::onFailure
|
||||
);
|
||||
|
||||
estimateMemoryUsage(task, request, fieldExtraction.v1(), memoryEstimationListener);
|
||||
estimateMemoryUsage(task, config, fieldExtraction.v1(), memoryEstimationListener);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -112,15 +149,15 @@ public class TransportExplainDataFrameAnalyticsAction
|
|||
* the ML node.
|
||||
*/
|
||||
private void estimateMemoryUsage(Task task,
|
||||
PutDataFrameAnalyticsAction.Request request,
|
||||
DataFrameAnalyticsConfig config,
|
||||
ExtractedFields extractedFields,
|
||||
ActionListener<MemoryEstimation> listener) {
|
||||
final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
|
||||
DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
|
||||
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, request.getConfig(), extractedFields);
|
||||
new ParentTaskAssigningClient(client, task.getParentTaskId()), estimateMemoryTaskId, config, extractedFields);
|
||||
processManager.runJobAsync(
|
||||
estimateMemoryTaskId,
|
||||
request.getConfig(),
|
||||
config,
|
||||
extractorFactory,
|
||||
ActionListener.wrap(
|
||||
result -> listener.onResponse(
|
||||
|
|
Loading…
Reference in New Issue