diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ClientHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ClientHelper.java index 49a457f0a7f..2832fd634dd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ClientHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ClientHelper.java @@ -5,11 +5,11 @@ */ package org.elasticsearch.xpack.core; -import org.elasticsearch.action.ActionType; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.client.Client; import org.elasticsearch.client.OriginSettingClient; @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationField; import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Supplier; @@ -33,9 +34,23 @@ public final class ClientHelper { /** * List of headers that are related to security */ - public static final Set SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationServiceField.RUN_AS_USER_HEADER, + public static final Set SECURITY_HEADER_FILTERS = + Sets.newHashSet( + AuthenticationServiceField.RUN_AS_USER_HEADER, AuthenticationField.AUTHENTICATION_KEY); + /** + * Leaves only headers that are related to security and filters out the rest. + * + * @param headers Headers to be filtered + * @return A portion of entries that are related to security + */ + public static Map filterSecurityHeaders(Map headers) { + return Objects.requireNonNull(headers).entrySet().stream() + .filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + /** * . * @deprecated use ThreadContext.ACTION_ORIGIN_TRANSIENT_NAME @@ -109,8 +124,7 @@ public final class ClientHelper { */ public static T executeWithHeaders(Map headers, String origin, Client client, Supplier supplier) { - Map filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + Map filteredHeaders = filterSecurityHeaders(headers); // no security headers, we will have to use the xpack internal user for // our execution by specifying the origin @@ -145,8 +159,7 @@ public final class ClientHelper { void executeWithHeadersAsync(Map headers, String origin, Client client, ActionType action, Request request, ActionListener listener) { - Map filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + Map filteredHeaders = filterSecurityHeaders(headers); final ThreadContext threadContext = client.threadPool().getThreadContext(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java index 8784727acdd..f0497f4d872 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java @@ -22,7 +22,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator; @@ -42,7 +41,8 @@ import java.util.Optional; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; -import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; public class MlMetadata implements XPackPlugin.XPackMetadataCustom { @@ -315,12 +315,9 @@ public class MlMetadata implements XPackPlugin.XPackMetadataCustom { if (headers.isEmpty() == false) { // Adjust the request, adding security headers from the current thread context - DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig); - Map securityHeaders = headers.entrySet().stream() - .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - builder.setHeaders(securityHeaders); - datafeedConfig = builder.build(); + datafeedConfig = new DatafeedConfig.Builder(datafeedConfig) + .setHeaders(filterSecurityHeaders(headers)) + .build(); } datafeeds.put(datafeedConfig.getId(), datafeedConfig); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java index b7621ba7581..71547688635 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java @@ -659,98 +659,113 @@ public class DatafeedConfig extends AbstractDiffable implements this.indicesOptions = config.indicesOptions; } - public void setId(String datafeedId) { + public Builder setId(String datafeedId) { id = ExceptionsHelper.requireNonNull(datafeedId, ID.getPreferredName()); + return this; } public String getId() { return id; } - public void setJobId(String jobId) { + public Builder setJobId(String jobId) { this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName()); + return this; } - public void setHeaders(Map headers) { + public Builder setHeaders(Map headers) { this.headers = ExceptionsHelper.requireNonNull(headers, HEADERS.getPreferredName()); + return this; } - public void setIndices(List indices) { + public Builder setIndices(List indices) { this.indices = ExceptionsHelper.requireNonNull(indices, INDICES.getPreferredName()); + return this; } - public void setQueryDelay(TimeValue queryDelay) { + public Builder setQueryDelay(TimeValue queryDelay) { TimeUtils.checkNonNegativeMultiple(queryDelay, TimeUnit.MILLISECONDS, QUERY_DELAY); this.queryDelay = queryDelay; + return this; } - public void setFrequency(TimeValue frequency) { + public Builder setFrequency(TimeValue frequency) { TimeUtils.checkPositiveMultiple(frequency, TimeUnit.SECONDS, FREQUENCY); this.frequency = frequency; + return this; } - public void setQueryProvider(QueryProvider queryProvider) { + public Builder setQueryProvider(QueryProvider queryProvider) { this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName()); + return this; } // For testing only - public void setParsedQuery(QueryBuilder queryBuilder) { + public Builder setParsedQuery(QueryBuilder queryBuilder) { try { this.queryProvider = ExceptionsHelper.requireNonNull(QueryProvider.fromParsedQuery(queryBuilder), QUERY.getPreferredName()); } catch (IOException exception) { // eat exception as it should never happen logger.error("Exception trying to setParsedQuery", exception); } + return this; } // For testing only - public void setParsedAggregations(AggregatorFactories.Builder aggregations) { + public Builder setParsedAggregations(AggregatorFactories.Builder aggregations) { try { this.aggProvider = AggProvider.fromParsedAggs(aggregations); } catch (IOException exception) { // eat exception as it should never happen logger.error("Exception trying to setParsedAggregations", exception); } + return this; } - private void setAggregationsSafe(AggProvider aggProvider) { + private Builder setAggregationsSafe(AggProvider aggProvider) { if (this.aggProvider != null) { throw ExceptionsHelper.badRequestException("Found two aggregation definitions: [aggs] and [aggregations]"); } this.aggProvider = aggProvider; + return this; } - public void setAggProvider(AggProvider aggProvider) { + public Builder setAggProvider(AggProvider aggProvider) { this.aggProvider = aggProvider; + return this; } - public void setScriptFields(List scriptFields) { + public Builder setScriptFields(List scriptFields) { List sorted = new ArrayList<>(); for (SearchSourceBuilder.ScriptField scriptField : scriptFields) { sorted.add(scriptField); } sorted.sort(Comparator.comparing(SearchSourceBuilder.ScriptField::fieldName)); this.scriptFields = sorted; + return this; } - public void setScrollSize(int scrollSize) { + public Builder setScrollSize(int scrollSize) { if (scrollSize < 0) { String msg = Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, DatafeedConfig.SCROLL_SIZE.getPreferredName(), scrollSize); throw ExceptionsHelper.badRequestException(msg); } this.scrollSize = scrollSize; + return this; } - public void setChunkingConfig(ChunkingConfig chunkingConfig) { + public Builder setChunkingConfig(ChunkingConfig chunkingConfig) { this.chunkingConfig = chunkingConfig; + return this; } - public void setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) { + public Builder setDelayedDataCheckConfig(DelayedDataCheckConfig delayedDataCheckConfig) { this.delayedDataCheckConfig = delayedDataCheckConfig; + return this; } - public void setMaxEmptySearches(int maxEmptySearches) { + public Builder setMaxEmptySearches(int maxEmptySearches) { if (maxEmptySearches == -1) { this.maxEmptySearches = null; } else if (maxEmptySearches <= 0) { @@ -760,6 +775,7 @@ public class DatafeedConfig extends AbstractDiffable implements } else { this.maxEmptySearches = maxEmptySearches; } + return this; } public Builder setIndicesOptions(IndicesOptions indicesOptions) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java index 6e47d720aec..5b2f14f4e08 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -38,7 +37,8 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; /** @@ -389,10 +389,7 @@ public class DatafeedUpdate implements Writeable, ToXContentObject { if (headers.isEmpty() == false) { // Adjust the request, adding security headers from the current thread context - Map securityHeaders = headers.entrySet().stream() - .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - builder.setHeaders(securityHeaders); + builder.setHeaders(filterSecurityHeaders(headers)); } return builder.build(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ClientHelperTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ClientHelperTests.java index 9641f097119..233c96b8213 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ClientHelperTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ClientHelperTests.java @@ -31,6 +31,8 @@ import java.util.concurrent.CountDownLatch; import java.util.function.Consumer; import static org.elasticsearch.xpack.core.ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; @@ -311,4 +313,30 @@ public class ClientHelperTests extends ESTestCase { return client.search(new SearchRequest()).actionGet(); }); } + + public void testFilterSecurityHeaders() { + { // Empty map + assertThat(ClientHelper.filterSecurityHeaders(Collections.emptyMap()), is(anEmptyMap())); + } + { // Singleton map with no security-related headers + assertThat(ClientHelper.filterSecurityHeaders(Collections.singletonMap("non-security-header", "value")), is(anEmptyMap())); + } + { // Singleton map with a security-related header + assertThat( + ClientHelper.filterSecurityHeaders(Collections.singletonMap(AuthenticationServiceField.RUN_AS_USER_HEADER, "value")), + hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value")); + } + { // Map with 3 headers out of which only 1 is security-related + Map headers = new HashMap<>(); + headers.put("non-security-header-1", "value-1"); + headers.put(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2"); + headers.put("other-non-security-header", "value-3"); + Map filteredHeaders = ClientHelper.filterSecurityHeaders(headers); + assertThat(filteredHeaders, is(aMapWithSize(1))); + assertThat(filteredHeaders, hasEntry(AuthenticationServiceField.RUN_AS_USER_HEADER, "value-2")); + } + { // null + expectThrows(NullPointerException.class, () -> ClientHelper.filterSecurityHeaders(null)); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDatafeedAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDatafeedAction.java index b90c96bdc52..6befb4808fd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDatafeedAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPreviewDatafeedAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; @@ -33,10 +32,10 @@ import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.nio.charset.StandardCharsets; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; import static org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable; public class TransportPreviewDatafeedAction extends HandledTransportAction { @@ -74,10 +73,7 @@ public class TransportPreviewDatafeedAction extends HandledTransportAction { DatafeedConfig.Builder previewDatafeed = buildPreviewDatafeed(datafeedConfig); useSecondaryAuthIfAvailable(securityContext, () -> { - Map headers = threadPool.getThreadContext().getHeaders().entrySet().stream() - .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - previewDatafeed.setHeaders(headers); + previewDatafeed.setHeaders(filterSecurityHeaders(threadPool.getThreadContext().getHeaders())); jobResultsProvider.datafeedTimingStats( jobBuilder.getId(), timingStats -> { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/persistence/DatafeedConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/persistence/DatafeedConfigProvider.java index f1c51fc2fd2..f77acb0a618 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/persistence/DatafeedConfigProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/persistence/DatafeedConfigProvider.java @@ -45,7 +45,6 @@ import org.elasticsearch.index.query.WildcardQueryBuilder; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; @@ -69,10 +68,10 @@ import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import java.util.function.BiConsumer; -import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; /** * This class implements CRUD operation for the @@ -108,12 +107,9 @@ public class DatafeedConfigProvider { if (headers.isEmpty() == false) { // Filter any values in headers that aren't security fields - DatafeedConfig.Builder builder = new DatafeedConfig.Builder(config); - Map securityHeaders = headers.entrySet().stream() - .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - builder.setHeaders(securityHeaders); - config = builder.build(); + config = new DatafeedConfig.Builder(config) + .setHeaders(filterSecurityHeaders(headers)) + .build(); } final String datafeedId = config.getId(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java index 8b55932a253..b6a389cf5fc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java @@ -30,7 +30,6 @@ import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -51,6 +50,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.xpack.core.ClientHelper.filterSecurityHeaders; public class DataFrameAnalyticsConfigProvider { @@ -73,12 +73,9 @@ public class DataFrameAnalyticsConfigProvider { if (headers.isEmpty() == false) { // Filter any values in headers that aren't security fields - DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(config); - Map securityHeaders = headers.entrySet().stream() - .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - builder.setHeaders(securityHeaders); - config = builder.build(); + config = new DataFrameAnalyticsConfig.Builder(config) + .setHeaders(filterSecurityHeaders(headers)) + .build(); } try (XContentBuilder builder = XContentFactory.jsonBuilder()) { config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS));