Extract ClientHelper.filterSecurityHeaders method and use it in ML code (#58447) (#58459)

This commit is contained in:
Przemysław Witek 2020-06-23 22:18:39 +02:00 committed by GitHub
parent b40c27698f
commit 4e4ca6ac25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 97 additions and 57 deletions

View File

@ -5,11 +5,11 @@
*/ */
package org.elasticsearch.xpack.core; package org.elasticsearch.xpack.core;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient; import org.elasticsearch.client.OriginSettingClient;
@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField; import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -33,9 +34,23 @@ public final class ClientHelper {
/** /**
* List of headers that are related to security * List of headers that are related to security
*/ */
public static final Set<String> SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationServiceField.RUN_AS_USER_HEADER, public static final Set<String> SECURITY_HEADER_FILTERS =
Sets.newHashSet(
AuthenticationServiceField.RUN_AS_USER_HEADER,
AuthenticationField.AUTHENTICATION_KEY); AuthenticationField.AUTHENTICATION_KEY);
/**
* Leaves only headers that are related to security and filters out the rest.
*
* @param headers Headers to be filtered
* @return A portion of entries that are related to security
*/
public static Map<String, String> filterSecurityHeaders(Map<String, String> headers) {
return Objects.requireNonNull(headers).entrySet().stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
/** /**
* . * .
* @deprecated use ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME * @deprecated use ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME
@ -109,8 +124,7 @@ public final class ClientHelper {
*/ */
public static <T extends ActionResponse> T executeWithHeaders(Map<String, String> headers, String origin, Client client, public static <T extends ActionResponse> T executeWithHeaders(Map<String, String> headers, String origin, Client client,
Supplier<T> supplier) { Supplier<T> supplier) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey())) Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
// no security headers, we will have to use the xpack internal user for // no security headers, we will have to use the xpack internal user for
// our execution by specifying the origin // our execution by specifying the origin
@ -145,8 +159,7 @@ public final class ClientHelper {
void executeWithHeadersAsync(Map<String, String> headers, String origin, Client client, ActionType<Response> action, Request request, void executeWithHeadersAsync(Map<String, String> headers, String origin, Client client, ActionType<Response> action, Request request,
ActionListener<Response> listener) { ActionListener<Response> listener) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey())) Map<String, String> filteredHeaders = filterSecurityHeaders(headers);
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
final ThreadContext threadContext = client.threadPool().getThreadContext(); final ThreadContext threadContext = client.threadPool().getThreadContext();

View File

@ -22,7 +22,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator;
@ -42,7 +41,8 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.SortedMap; import java.util.SortedMap;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
public class MlMetadata implements XPackPlugin.XPackMetadataCustom { public class MlMetadata implements XPackPlugin.XPackMetadataCustom {
@ -315,12 +315,9 @@ public class MlMetadata implements XPackPlugin.XPackMetadataCustom {
if (headers.isEmpty() == false) { if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context // Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig); datafeedConfig = new DatafeedConfig.Builder(datafeedConfig)
Map<String, String> securityHeaders = headers.entrySet().stream() .setHeaders(filterSecurityHeaders(headers))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) .build();
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
datafeedConfig = builder.build();
} }
datafeeds.put(datafeedConfig.getId(), datafeedConfig); datafeeds.put(datafeedConfig.getId(), datafeedConfig);

View File

@ -659,98 +659,113 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
this.indicesOptions = config.indicesOptions; this.indicesOptions = config.indicesOptions;
} }
public void setId(String datafeedId) { public Builder setId(String datafeedId) {
id = ExceptionsHelper.requireNonNull(datafeedId, ID.getPreferredName()); id = ExceptionsHelper.requireNonNull(datafeedId, ID.getPreferredName());
return this;
} }
public String getId() { public String getId() {
return id; return id;
} }
public void setJobId(String jobId) { public Builder setJobId(String jobId) {
this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName()); this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName());
return this;
} }
public void setHeaders(Map<String, String> headers) { public Builder setHeaders(Map<String, String> headers) {
this.headers = ExceptionsHelper.requireNonNull(headers, HEADERS.getPreferredName()); this.headers = ExceptionsHelper.requireNonNull(headers, HEADERS.getPreferredName());
return this;
} }
public void setIndices(List<String> indices) { public Builder setIndices(List<String> indices) {
this.indices = ExceptionsHelper.requireNonNull(indices, INDICES.getPreferredName()); this.indices = ExceptionsHelper.requireNonNull(indices, INDICES.getPreferredName());
return this;
} }
public void setQueryDelay(TimeValue queryDelay) { public Builder setQueryDelay(TimeValue queryDelay) {
TimeUtils.checkNonNegativeMultiple(queryDelay, TimeUnit.MILLISECONDS, QUERY_DELAY); TimeUtils.checkNonNegativeMultiple(queryDelay, TimeUnit.MILLISECONDS, QUERY_DELAY);
this.queryDelay = queryDelay; this.queryDelay = queryDelay;
return this;
} }
public void setFrequency(TimeValue frequency) { public Builder setFrequency(TimeValue frequency) {
TimeUtils.checkPositiveMultiple(frequency, TimeUnit.SECONDS, FREQUENCY); TimeUtils.checkPositiveMultiple(frequency, TimeUnit.SECONDS, FREQUENCY);
this.frequency = frequency; this.frequency = frequency;
return this;
} }
public void setQueryProvider(QueryProvider queryProvider) { public Builder setQueryProvider(QueryProvider queryProvider) {
this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName()); this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName());
return this;
} }
// For testing only // For testing only
public void setParsedQuery(QueryBuilder queryBuilder) { public Builder setParsedQuery(QueryBuilder queryBuilder) {
try { try {
this.queryProvider = ExceptionsHelper.requireNonNull(QueryProvider.fromParsedQuery(queryBuilder), QUERY.getPreferredName()); this.queryProvider = ExceptionsHelper.requireNonNull(QueryProvider.fromParsedQuery(queryBuilder), QUERY.getPreferredName());
} catch (IOException exception) { } catch (IOException exception) {
// eat exception as it should never happen // eat exception as it should never happen
logger.error("Exception trying to setParsedQuery", exception); logger.error("Exception trying to setParsedQuery", exception);
} }
return this;
} }
// For testing only // For testing only
public void setParsedAggregations(AggregatorFactories.Builder aggregations) { public Builder setParsedAggregations(AggregatorFactories.Builder aggregations) {
try { try {
this.aggProvider = AggProvider.fromParsedAggs(aggregations); this.aggProvider = AggProvider.fromParsedAggs(aggregations);
} catch (IOException exception) { } catch (IOException exception) {
// eat exception as it should never happen // eat exception as it should never happen
logger.error("Exception trying to setParsedAggregations", exception); logger.error("Exception trying to setParsedAggregations", exception);
} }
return this;
} }
private void setAggregationsSafe(AggProvider aggProvider) { private Builder setAggregationsSafe(AggProvider aggProvider) {
if (this.aggProvider != null) { if (this.aggProvider != null) {
throw ExceptionsHelper.badRequestException("Found two aggregation definitions: [aggs] and [aggregations]"); throw ExceptionsHelper.badRequestException("Found two aggregation definitions: [aggs] and [aggregations]");
} }
this.aggProvider = aggProvider; this.aggProvider = aggProvider;
return this;
} }
public void setAggProvider(AggProvider aggProvider) { public Builder setAggProvider(AggProvider aggProvider) {
this.aggProvider = aggProvider; this.aggProvider = aggProvider;
return this;
} }
public void setScriptFields(List<SearchSourceBuilder.ScriptField> scriptFields) { public Builder setScriptFields(List<SearchSourceBuilder.ScriptField> scriptFields) {
List<SearchSourceBuilder.ScriptField> sorted = new ArrayList<>(); List<SearchSourceBuilder.ScriptField> sorted = new ArrayList<>();
for (SearchSourceBuilder.ScriptField scriptField : scriptFields) { for (SearchSourceBuilder.ScriptField scriptField : scriptFields) {
sorted.add(scriptField); sorted.add(scriptField);
} }
sorted.sort(Comparator.comparing(SearchSourceBuilder.ScriptField::fieldName)); sorted.sort(Comparator.comparing(SearchSourceBuilder.ScriptField::fieldName));
this.scriptFields = sorted; this.scriptFields = sorted;
return this;
} }
public void setScrollSize(int scrollSize) { public Builder setScrollSize(int scrollSize) {
if (scrollSize < 0) { if (scrollSize < 0) {
String msg = Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, String msg = Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE,
DatafeedConfig.SCROLL_SIZE.getPreferredName(), scrollSize); DatafeedConfig.SCROLL_SIZE.getPreferredName(), scrollSize);
throw ExceptionsHelper.badRequestException(msg); throw ExceptionsHelper.badRequestException(msg);
} }
this.scrollSize = scrollSize; this.scrollSize = scrollSize;
return this;
} }
public void setChunkingConfig(ChunkingConfig chunkingConfig) { public Builder setChunkingConfig(ChunkingConfig chunkingConfig) {
this.chunkingConfig = chunkingConfig; this.chunkingConfig = chunkingConfig;
return this;
} }
public void setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) { public Builder setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) {
this.delayedDataCheckConfig = delayedDataCheckConfig; this.delayedDataCheckConfig = delayedDataCheckConfig;
return this;
} }
public void setMaxEmptySearches(int maxEmptySearches) { public Builder setMaxEmptySearches(int maxEmptySearches) {
if (maxEmptySearches == -1) { if (maxEmptySearches == -1) {
this.maxEmptySearches = null; this.maxEmptySearches = null;
} else if (maxEmptySearches <= 0) { } else if (maxEmptySearches <= 0) {
@ -760,6 +775,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
} else { } else {
this.maxEmptySearches = maxEmptySearches; this.maxEmptySearches = maxEmptySearches;
} }
return this;
} }
public Builder setIndicesOptions(IndicesOptions indicesOptions) { public Builder setIndicesOptions(IndicesOptions indicesOptions) {

View File

@ -24,7 +24,6 @@ import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@ -38,7 +37,8 @@ import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
/** /**
@ -389,10 +389,7 @@ public class DatafeedUpdate implements Writeable, ToXContentObject {
if (headers.isEmpty() == false) { if (headers.isEmpty() == false) {
// Adjust the request, adding security headers from the current thread context // Adjust the request, adding security headers from the current thread context
Map<String, String> securityHeaders = headers.entrySet().stream() builder.setHeaders(filterSecurityHeaders(headers));
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
} }
return builder.build(); return builder.build();

View File

@ -31,6 +31,8 @@ import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.elasticsearch.xpack.core.ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME; import static org.elasticsearch.xpack.core.ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.hasSize;
@ -311,4 +313,30 @@ public class ClientHelperTests extends ESTestCase {
return client.search(new SearchRequest()).actionGet(); return client.search(new SearchRequest()).actionGet();
}); });
} }
public void testFilterSecurityHeaders() {
{ // Empty map
assertThat(ClientHelper.filterSecurityHeaders(Collections.emptyMap()), is(anEmptyMap()));
}
{ // Singleton map with no security-related headers
assertThat(ClientHelper.filterSecurityHeaders(Collections.singletonMap("non-security-header", "value")), is(anEmptyMap()));
}
{ // Singleton map with a security-related header
assertThat(
ClientHelper.filterSecurityHeaders(Collections.singletonMap(AuthenticationServiceField.RUN_AS_USER_HEADER, "value")),
hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value"));
}
{ // Map with 3 headers out of which only 1 is security-related
Map<String, String> headers = new HashMap<>();
headers.put("non-security-header-1", "value-1");
headers.put(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2");
headers.put("other-non-security-header", "value-3");
Map<String, String> filteredHeaders = ClientHelper.filterSecurityHeaders(headers);
assertThat(filteredHeaders, is(aMapWithSize(1)));
assertThat(filteredHeaders, hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2"));
}
{ // null
expectThrows(NullPointerException.class, () -> ClientHelper.filterSecurityHeaders(null));
}
}
} }

View File

@ -16,7 +16,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction;
import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig;
@ -33,10 +32,10 @@ import java.io.BufferedReader;
import java.io.InputStream; import java.io.InputStream;
import java.io.InputStreamReader; import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable; import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable;
public class TransportPreviewDatafeedAction extends HandledTransportAction<PreviewDatafeedAction.Request, PreviewDatafeedAction.Response> { public class TransportPreviewDatafeedAction extends HandledTransportAction<PreviewDatafeedAction.Request, PreviewDatafeedAction.Response> {
@ -74,10 +73,7 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction<Previ
jobBuilder -> { jobBuilder -> {
DatafeedConfig.Builder previewDatafeed = buildPreviewDatafeed(datafeedConfig); DatafeedConfig.Builder previewDatafeed = buildPreviewDatafeed(datafeedConfig);
useSecondaryAuthIfAvailable(securityContext, () -> { useSecondaryAuthIfAvailable(securityContext, () -> {
Map<String, String> headers = threadPool.getThreadContext().getHeaders().entrySet().stream() previewDatafeed.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders()));
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
previewDatafeed.setHeaders(headers);
jobResultsProvider.datafeedTimingStats( jobResultsProvider.datafeedTimingStats(
jobBuilder.getId(), jobBuilder.getId(),
timingStats -> { timingStats -> {

View File

@ -45,7 +45,6 @@ import org.elasticsearch.index.query.WildcardQueryBuilder;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@ -69,10 +68,10 @@ import java.util.Set;
import java.util.SortedSet; import java.util.SortedSet;
import java.util.TreeSet; import java.util.TreeSet;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
/** /**
* This class implements CRUD operation for the * This class implements CRUD operation for the
@ -108,12 +107,9 @@ public class DatafeedConfigProvider {
if (headers.isEmpty() == false) { if (headers.isEmpty() == false) {
// Filter any values in headers that aren't security fields // Filter any values in headers that aren't security fields
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(config); config = new DatafeedConfig.Builder(config)
Map<String, String> securityHeaders = headers.entrySet().stream() .setHeaders(filterSecurityHeaders(headers))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) .build();
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
config = builder.build();
} }
final String datafeedId = config.getId(); final String datafeedId = config.getId();

View File

@ -30,7 +30,6 @@ import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@ -51,6 +50,7 @@ import java.util.stream.Collectors;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders;
public class DataFrameAnalyticsConfigProvider { public class DataFrameAnalyticsConfigProvider {
@ -73,12 +73,9 @@ public class DataFrameAnalyticsConfigProvider {
if (headers.isEmpty() == false) { if (headers.isEmpty() == false) {
// Filter any values in headers that aren't security fields // Filter any values in headers that aren't security fields
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(config); config = new DataFrameAnalyticsConfig.Builder(config)
Map<String, String> securityHeaders = headers.entrySet().stream() .setHeaders(filterSecurityHeaders(headers))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) .build();
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(securityHeaders);
config = builder.build();
} }
try (XContentBuilder builder = XContentFactory.jsonBuilder()) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS)); config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS));