[ML] Make datafeeds run-as the user who created/updated them (elastic/x-pack-elasticsearch#3254)

This is the ML equivalent of what was done for Watcher in elastic/x-pack-elasticsearch#2808.

For security reasons, ML datafeeds should not run as the _xpack
user. Instead, they record the security headers from the request
that created or updated them, and reuse those headers when running
the search that retrieves data for analysis.

Relates elastic/x-pack-elasticsearch#1071

Original commit: elastic/x-pack-elasticsearch@29f85de404
David Roberts 2017-12-11 13:01:16 +00:00 committed by GitHub
parent 6bae4681e2
commit 5fd68959a0
31 changed files with 520 additions and 185 deletions
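In outline, the change works like this (a minimal, self-contained sketch rather than the actual x-pack code; the literal header names below are stand-ins for Authentication.AUTHENTICATION_KEY and AuthenticationService.RUN_AS_USER_HEADER): the security-related headers of the request that creates or updates a datafeed are filtered out of the thread context, stored on the DatafeedConfig, and replayed when the datafeed later searches for data.

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class CapturedHeadersSketch {

    // Stand-ins for the real security header names used by the x-pack security plugin.
    static final Set<String> SECURITY_HEADER_FILTERS =
            new HashSet<>(Arrays.asList("_xpack_security_authentication", "es-security-runas-user"));

    // Keep only the security-related headers of the creating/updating request; the real code
    // does this in MlMetadata.Builder#putDatafeed and DatafeedUpdate#apply from the ThreadContext.
    static Map<String, String> filterSecurityHeaders(Map<String, String> requestHeaders) {
        return requestHeaders.entrySet().stream()
                .filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    public static void main(String[] args) {
        Map<String, String> requestHeaders = new HashMap<>();
        requestHeaders.put("_xpack_security_authentication", "<serialized authentication>");
        requestHeaders.put("Content-Type", "application/json");
        // Only the authentication header survives; it is stored with the datafeed config
        // and copied back into the thread context when the datafeed runs its searches.
        System.out.println(filterSecurityHeaders(requestHeaders));
    }
}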

View File

@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.ClientHelper.stashWithOrigin;
/**
* A helper class for actions which decides if we should run via the _xpack user and set ML as origin
* or if we should use the run_as functionality by setting the correct headers
*/
public class MlClientHelper {
/**
* List of headers that are related to security
*/
public static final Set<String> SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationService.RUN_AS_USER_HEADER,
Authentication.AUTHENTICATION_KEY);
/**
* Execute a client operation and return the response, try to run a datafeed search with least privileges, when headers exist
*
* @param datafeedConfig The config for a datafeed
* @param client The client used to query
* @param supplier The action to run
* @return An instance of the response class
*/
public static <T extends ActionResponse> T execute(DatafeedConfig datafeedConfig, Client client, Supplier<T> supplier) {
return execute(datafeedConfig.getHeaders(), client, supplier);
}
/**
* Execute a client operation and return the response, try to run an action with least privileges, when headers exist
*
* @param headers Request headers, ideally including security headers
* @param client The client used to query
* @param supplier The action to run
* @return An instance of the response class
*/
public static <T extends ActionResponse> T execute(Map<String, String> headers, Client client, Supplier<T> supplier) {
// no headers, we will have to use the xpack internal user for our execution by specifying the ml origin
if (headers == null || headers.isEmpty()) {
try (ThreadContext.StoredContext ignore = stashWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN)) {
return supplier.get();
}
} else {
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashContext()) {
Map<String, String> filteredHeaders = headers.entrySet().stream()
.filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
client.threadPool().getThreadContext().copyHeaders(filteredHeaders.entrySet());
return supplier.get();
}
}
}
}
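For illustration, a hypothetical caller of this helper (not part of the commit; the data extractor classes changed further down use the same pattern in their executeSearchRequest() overrides). With stored headers the search runs with the creating user's privileges; without them it falls back to the internal _xpack user with the ml origin.

import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.xpack.ml.MlClientHelper;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;

// Hypothetical helper class, shown only to demonstrate the calling convention.
class DatafeedSearchSketch {
    static SearchResponse searchAsDatafeedOwner(Client client, DatafeedConfig datafeedConfig, String index) {
        // If the config carries security headers they are copied into a stashed thread context;
        // otherwise the call is stashed with the ml origin and runs as the _xpack user.
        return MlClientHelper.execute(datafeedConfig, client,
                () -> client.prepareSearch(index).setSize(1000).get());
    }
}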

View File

@@ -19,6 +19,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -48,6 +49,7 @@ import java.util.Set;
 import java.util.SortedMap;
 import java.util.TreeMap;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;

 public class MlMetadata implements MetaData.Custom {
@@ -101,7 +103,7 @@ public class MlMetadata implements MetaData.Custom {
     }

     public Set<String> expandDatafeedIds(String expression, boolean allowNoDatafeeds) {
-        return NameResolver.newUnaliased(datafeeds.keySet(), datafeedId -> ExceptionsHelper.missingDatafeedException(datafeedId))
+        return NameResolver.newUnaliased(datafeeds.keySet(), ExceptionsHelper::missingDatafeedException)
                 .expand(expression, allowNoDatafeeds);
     }
@@ -285,7 +287,7 @@ public class MlMetadata implements MetaData.Custom {
             return this;
         }

-        public Builder putDatafeed(DatafeedConfig datafeedConfig) {
+        public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
             if (datafeeds.containsKey(datafeedConfig.getId())) {
                 throw new ResourceAlreadyExistsException("A datafeed with id [" + datafeedConfig.getId() + "] already exists");
             }
@@ -293,6 +295,17 @@ public class MlMetadata implements MetaData.Custom {
             checkJobIsAvailableForDatafeed(jobId);
             Job job = jobs.get(jobId);
             DatafeedJobValidator.validate(datafeedConfig, job);
+            if (threadContext != null) {
+                // Adjust the request, adding security headers from the current thread context
+                DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
+                Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
+                        .filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
+                        .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+                builder.setHeaders(headers);
+                datafeedConfig = builder.build();
+            }
             datafeeds.put(datafeedConfig.getId(), datafeedConfig);
             return this;
         }
@@ -309,7 +322,7 @@ public class MlMetadata implements MetaData.Custom {
             }
         }

-        public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks) {
+        public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, ThreadContext threadContext) {
             String datafeedId = update.getId();
             DatafeedConfig oldDatafeedConfig = datafeeds.get(datafeedId);
             if (oldDatafeedConfig == null) {
@@ -317,7 +330,7 @@ public class MlMetadata implements MetaData.Custom {
             }
             checkDatafeedIsStopped(() -> Messages.getMessage(Messages.DATAFEED_CANNOT_UPDATE_IN_CURRENT_STATE, datafeedId,
                     DatafeedState.STARTED), datafeedId, persistentTasks);
-            DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig);
+            DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, threadContext);
             if (newDatafeedConfig.getJobId().equals(oldDatafeedConfig.getJobId()) == false) {
                 checkJobIsAvailableForDatafeed(newDatafeedConfig.getJobId());
             }
@@ -393,14 +406,13 @@ public class MlMetadata implements MetaData.Custom {
             putJob(jobBuilder.build(), true);
         }

-        public void checkJobHasNoDatafeed(String jobId) {
+        void checkJobHasNoDatafeed(String jobId) {
             Optional<DatafeedConfig> datafeed = getDatafeedByJobId(jobId);
             if (datafeed.isPresent()) {
                 throw ExceptionsHelper.conflictStatusException("Cannot delete job [" + jobId + "] because datafeed ["
                         + datafeed.get().getId() + "] refers to it");
             }
         }
     }

     /**

View File

@@ -49,7 +49,9 @@ import org.elasticsearch.xpack.security.authz.RoleDescriptor;
 import org.elasticsearch.xpack.security.support.Exceptions;

 import java.io.IOException;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;

 public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutDatafeedAction.Response, PutDatafeedAction.RequestBuilder> {
@@ -218,8 +220,7 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
         }

         @Override
-        protected void masterOperation(Request request, ClusterState state,
-                                       ActionListener<Response> listener) throws Exception {
+        protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
             // If security is enabled only create the datafeed if the user requesting creation has
             // permission to read the indices the datafeed is going to read from
             if (securityEnabled) {
@@ -266,6 +267,7 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
         }

         private void putDatafeed(Request request, ActionListener<Response> listener) {
             clusterService.submitStateUpdateTask(
                     "put-datafeed-" + request.getDatafeed().getId(),
                     new AckedClusterStateUpdateTask<Response>(request, listener) {
@@ -275,13 +277,11 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
                         if (acknowledged) {
                             logger.info("Created datafeed [{}]", request.getDatafeed().getId());
                         }
-                        return new Response(acknowledged,
-                                request.getDatafeed());
+                        return new Response(acknowledged, request.getDatafeed());
                     }

                     @Override
-                    public ClusterState execute(ClusterState currentState)
-                            throws Exception {
+                    public ClusterState execute(ClusterState currentState) {
                         return putDatafeed(request, currentState);
                     }
                 });
@@ -290,7 +290,7 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
         private ClusterState putDatafeed(Request request, ClusterState clusterState) {
             MlMetadata currentMetadata = clusterState.getMetaData().custom(MlMetadata.TYPE);
             MlMetadata newMetadata = new MlMetadata.Builder(currentMetadata)
-                    .putDatafeed(request.getDatafeed()).build();
+                    .putDatafeed(request.getDatafeed(), threadPool.getThreadContext()).build();
             return ClusterState.builder(clusterState).metaData(
                     MetaData.builder(clusterState.getMetaData()).putCustom(MlMetadata.TYPE, newMetadata).build())
                     .build();

View File

@@ -72,9 +72,6 @@ import java.util.Objects;
 import java.util.function.LongSupplier;
 import java.util.function.Predicate;

-import static org.elasticsearch.xpack.ClientHelper.ML_ORIGIN;
-import static org.elasticsearch.xpack.ClientHelper.clientWithOrigin;

 public class StartDatafeedAction
         extends Action<StartDatafeedAction.Request, StartDatafeedAction.Response, StartDatafeedAction.RequestBuilder> {
@@ -437,7 +434,7 @@ public class StartDatafeedAction
             super(settings, NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, Request::new);
             this.licenseState = licenseState;
             this.persistentTasksService = persistentTasksService;
-            this.client = clientWithOrigin(client, ML_ORIGIN);
+            this.client = client;
         }

         @Override
@@ -453,7 +450,7 @@ public class StartDatafeedAction
         }

         @Override
-        protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
+        protected void masterOperation(Request request, ClusterState state, ActionListener<Response> listener) {
             DatafeedParams params = request.params;
             if (licenseState.isMachineLearningAllowed()) {
                 ActionListener<PersistentTask<DatafeedParams>> finalListener = new ActionListener<PersistentTask<DatafeedParams>>() {

View File

@@ -143,8 +143,7 @@ public class UpdateDatafeedAction extends Action<UpdateDatafeedAction.Request, P
         }

         @Override
-        protected void masterOperation(Request request, ClusterState state, ActionListener<PutDatafeedAction.Response> listener)
-                throws Exception {
+        protected void masterOperation(Request request, ClusterState state, ActionListener<PutDatafeedAction.Response> listener) {
             clusterService.submitStateUpdateTask("update-datafeed-" + request.getUpdate().getId(),
                     new AckedClusterStateUpdateTask<PutDatafeedAction.Response>(request, listener) {
                         private volatile DatafeedConfig updatedDatafeed;
@@ -164,7 +163,7 @@ public class UpdateDatafeedAction extends Action<UpdateDatafeedAction.Request, P
                             PersistentTasksCustomMetaData persistentTasks =
                                     currentState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE);
                             MlMetadata newMetadata = new MlMetadata.Builder(currentMetadata)
-                                    .updateDatafeed(update, persistentTasks).build();
+                                    .updateDatafeed(update, persistentTasks, threadPool.getThreadContext()).build();
                             updatedDatafeed = newMetadata.getDatafeed(update.getId());
                             return ClusterState.builder(currentState).metaData(
                                     MetaData.builder(currentState.getMetaData()).putCustom(MlMetadata.TYPE, newMetadata).build()).build();

View File

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.utils.MlStrings;
+import org.elasticsearch.xpack.ml.utils.ToXContentParams;
 import org.elasticsearch.xpack.ml.utils.time.TimeUtils;

 import java.io.IOException;
@@ -81,6 +82,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
     public static final ParseField SCRIPT_FIELDS = new ParseField("script_fields");
     public static final ParseField SOURCE = new ParseField("_source");
     public static final ParseField CHUNKING_CONFIG = new ParseField("chunking_config");
+    public static final ParseField HEADERS = new ParseField("headers");

     // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
     public static final ObjectParser<Builder, Void> METADATA_PARSER = new ObjectParser<>("datafeed_config", true, Builder::new);
@@ -117,6 +119,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
         // TODO this is to read former _source field. Remove in v7.0.0
         parser.declareBoolean((builder, value) -> {}, SOURCE);
         parser.declareObject(Builder::setChunkingConfig, ChunkingConfig.PARSERS.get(parserType), CHUNKING_CONFIG);
+        parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS);
     }
 }
@@ -140,10 +143,11 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
     private final List<SearchSourceBuilder.ScriptField> scriptFields;
     private final Integer scrollSize;
     private final ChunkingConfig chunkingConfig;
+    private final Map<String, String> headers;

     private DatafeedConfig(String id, String jobId, TimeValue queryDelay, TimeValue frequency, List<String> indices, List<String> types,
                            QueryBuilder query, AggregatorFactories.Builder aggregations, List<SearchSourceBuilder.ScriptField> scriptFields,
-                           Integer scrollSize, ChunkingConfig chunkingConfig) {
+                           Integer scrollSize, ChunkingConfig chunkingConfig, Map<String, String> headers) {
         this.id = id;
         this.jobId = jobId;
         this.queryDelay = queryDelay;
@@ -155,6 +159,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
         this.scriptFields = scriptFields;
         this.scrollSize = scrollSize;
         this.chunkingConfig = chunkingConfig;
+        this.headers = Objects.requireNonNull(headers);
     }

     public DatafeedConfig(StreamInput in) throws IOException {
@@ -185,6 +190,11 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
             in.readBoolean();
         }
         this.chunkingConfig = in.readOptionalWriteable(ChunkingConfig::new);
+        if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
+            this.headers = in.readMap(StreamInput::readString, StreamInput::readString);
+        } else {
+            this.headers = Collections.emptyMap();
+        }
     }

     public String getId() {
@@ -245,6 +255,10 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
         return chunkingConfig;
     }

+    public Map<String, String> getHeaders() {
+        return headers;
+    }
+
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(id);
@@ -277,6 +291,9 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
             out.writeBoolean(false);
         }
         out.writeOptionalWriteable(chunkingConfig);
+        if (out.getVersion().onOrAfter(Version.V_6_2_0)) {
+            out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString);
+        }
     }

     @Override
@@ -311,6 +328,10 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
         if (chunkingConfig != null) {
             builder.field(CHUNKING_CONFIG.getPreferredName(), chunkingConfig);
         }
+        if (headers != null && headers.isEmpty() == false
+                && params.paramAsBoolean(ToXContentParams.FOR_CLUSTER_STATE, false) == true) {
+            builder.field(HEADERS.getPreferredName(), headers);
+        }
         return builder;
     }
@@ -341,13 +362,14 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
                 && Objects.equals(this.scrollSize, that.scrollSize)
                 && Objects.equals(this.aggregations, that.aggregations)
                 && Objects.equals(this.scriptFields, that.scriptFields)
-                && Objects.equals(this.chunkingConfig, that.chunkingConfig);
+                && Objects.equals(this.chunkingConfig, that.chunkingConfig)
+                && Objects.equals(this.headers, that.headers);
     }

     @Override
     public int hashCode() {
         return Objects.hash(id, jobId, frequency, queryDelay, indices, types, query, scrollSize, aggregations, scriptFields,
-                chunkingConfig);
+                chunkingConfig, headers);
     }

     @Override
@@ -420,6 +442,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
         private List<SearchSourceBuilder.ScriptField> scriptFields;
         private Integer scrollSize = DEFAULT_SCROLL_SIZE;
         private ChunkingConfig chunkingConfig;
+        private Map<String, String> headers = Collections.emptyMap();

         public Builder() {
         }
@@ -442,6 +465,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
             this.scriptFields = config.scriptFields;
             this.scrollSize = config.scrollSize;
             this.chunkingConfig = config.chunkingConfig;
+            this.headers = config.headers;
         }

         public void setId(String datafeedId) {
@@ -452,6 +476,10 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
             this.jobId = ExceptionsHelper.requireNonNull(jobId, Job.ID.getPreferredName());
         }

+        public void setHeaders(Map<String, String> headers) {
+            this.headers = headers;
+        }
+
         public void setIndices(List<String> indices) {
             this.indices = ExceptionsHelper.requireNonNull(indices, INDICES.getPreferredName());
         }
@@ -516,7 +544,7 @@ public class DatafeedConfig extends AbstractDiffable<DatafeedConfig> implements
             setDefaultChunkingConfig();
             setDefaultQueryDelay();
             return new DatafeedConfig(id, jobId, queryDelay, frequency, indices, types, query, aggregations, scriptFields, scrollSize,
-                    chunkingConfig);
+                    chunkingConfig, headers);
         }

         void validateAggregations() {

View File

@@ -25,9 +25,6 @@ import java.util.Objects;
 import java.util.function.Consumer;
 import java.util.function.Supplier;

-import static org.elasticsearch.xpack.ClientHelper.ML_ORIGIN;
-import static org.elasticsearch.xpack.ClientHelper.clientWithOrigin;

 public class DatafeedJobBuilder {

     private final Client client;
@@ -36,7 +33,7 @@ public class DatafeedJobBuilder {
     private final Supplier<Long> currentTimeSupplier;

     public DatafeedJobBuilder(Client client, JobProvider jobProvider, Auditor auditor, Supplier<Long> currentTimeSupplier) {
-        this.client = clientWithOrigin(client, ML_ORIGIN);
+        this.client = client;
         this.jobProvider = Objects.requireNonNull(jobProvider);
         this.auditor = Objects.requireNonNull(auditor);
         this.currentTimeSupplier = Objects.requireNonNull(currentTimeSupplier);

View File

@@ -17,6 +17,7 @@ import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.FutureUtils;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.ml.MachineLearning;
@@ -463,7 +464,16 @@ public class DatafeedManager extends AbstractComponent {
         }

         private void runTask(StartDatafeedAction.DatafeedTask task) {
-            innerRun(runningDatafeedsOnThisNode.get(task.getAllocationId()), task.getDatafeedStartTime(), task.getEndTime());
+            // This clearing of the thread context is not strictly necessary. Every action performed by the
+            // datafeed _should_ be done using the MlClientHelper, which will set the appropriate thread
+            // context. However, by clearing the thread context here if anyone forgets to use MlClientHelper
+            // somewhere else in the datafeed code then it should cause a failure in the same way in single
+            // and multi node clusters. If we didn't clear the thread context here then there's a risk that
+            // a context with sufficient permissions would coincidentally be in force in some single node
+            // tests, leading to bugs not caught in CI due to many tests running in single node test clusters.
+            try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
+                innerRun(runningDatafeedsOnThisNode.get(task.getAllocationId()), task.getDatafeedStartTime(), task.getEndTime());
+            }
         }

         @Override

View File

@@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -20,6 +21,7 @@ import org.elasticsearch.index.query.AbstractQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.AggregatorFactories;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.xpack.ml.MlClientHelper;
 import org.elasticsearch.xpack.ml.datafeed.extractor.ExtractorUtils;
 import org.elasticsearch.xpack.ml.job.config.Job;
 import org.elasticsearch.xpack.ml.utils.ExceptionsHelper;
@@ -29,7 +31,9 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.stream.Collectors;

 /**
  * A datafeed update contains partial properties to update a {@link DatafeedConfig}.
@@ -260,7 +264,7 @@ public class DatafeedUpdate implements Writeable, ToXContentObject {
      * Applies the update to the given {@link DatafeedConfig}
      * @return a new {@link DatafeedConfig} that contains the update
      */
-    public DatafeedConfig apply(DatafeedConfig datafeedConfig) {
+    public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
         if (id.equals(datafeedConfig.getId()) == false) {
             throw new IllegalArgumentException("Cannot apply update to datafeedConfig with different id");
         }
@@ -296,6 +300,15 @@ public class DatafeedUpdate implements Writeable, ToXContentObject {
         if (chunkingConfig != null) {
             builder.setChunkingConfig(chunkingConfig);
         }
+        if (threadContext != null) {
+            // Adjust the request, adding security headers from the current thread context
+            Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
+                    .filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
+                    .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+            builder.setHeaders(headers);
+        }
         return builder.build();
     }

View File

@@ -14,6 +14,7 @@ import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.logging.Loggers;
 import org.elasticsearch.search.aggregations.Aggregation;
 import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.ml.MlClientHelper;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractor;
 import org.elasticsearch.xpack.ml.datafeed.extractor.ExtractorUtils;
@@ -111,7 +112,7 @@ class AggregationDataExtractor implements DataExtractor {
     }

     protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
-        return searchRequestBuilder.get();
+        return MlClientHelper.execute(context.headers, client, searchRequestBuilder::get);
     }

     private SearchRequestBuilder buildSearchRequest() {

View File

@@ -9,6 +9,7 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.aggregations.AggregatorFactories;

 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
@@ -24,9 +25,11 @@ class AggregationDataExtractorContext {
     final long start;
     final long end;
     final boolean includeDocCount;
+    final Map<String, String> headers;

     AggregationDataExtractorContext(String jobId, String timeField, Set<String> fields, List<String> indices, List<String> types,
-                                    QueryBuilder query, AggregatorFactories.Builder aggs, long start, long end, boolean includeDocCount) {
+                                    QueryBuilder query, AggregatorFactories.Builder aggs, long start, long end, boolean includeDocCount,
+                                    Map<String, String> headers) {
         this.jobId = Objects.requireNonNull(jobId);
         this.timeField = Objects.requireNonNull(timeField);
         this.fields = Objects.requireNonNull(fields);
@@ -37,5 +40,6 @@ class AggregationDataExtractorContext {
         this.start = start;
         this.end = end;
         this.includeDocCount = includeDocCount;
+        this.headers = headers;
     }
 }

View File

@@ -39,7 +39,8 @@ public class AggregationDataExtractorFactory implements DataExtractorFactory {
                 datafeedConfig.getAggregations(),
                 Intervals.alignToCeil(start, histogramInterval),
                 Intervals.alignToFloor(end, histogramInterval),
-                job.getAnalysisConfig().getSummaryCountFieldName().equals(DatafeedConfig.DOC_COUNT));
+                job.getAnalysisConfig().getSummaryCountFieldName().equals(DatafeedConfig.DOC_COUNT),
+                datafeedConfig.getHeaders());
         return new AggregationDataExtractor(client, dataExtractorContext);
     }
 }

View File

@@ -15,6 +15,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilders;
 import org.elasticsearch.search.aggregations.Aggregations;
 import org.elasticsearch.search.aggregations.metrics.max.Max;
 import org.elasticsearch.search.aggregations.metrics.min.Min;
+import org.elasticsearch.xpack.ml.MlClientHelper;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractor;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory;
 import org.elasticsearch.xpack.ml.datafeed.extractor.ExtractorUtils;
@@ -133,7 +134,7 @@ public class ChunkedDataExtractor implements DataExtractor {
     }

     protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
-        return searchRequestBuilder.get();
+        return MlClientHelper.execute(context.headers, client, searchRequestBuilder::get);
     }

     private Optional<InputStream> getNextStream() throws IOException {

View File

@@ -10,6 +10,7 @@ import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilder;

 import java.util.List;
+import java.util.Map;
 import java.util.Objects;

 class ChunkedDataExtractorContext {
@@ -29,10 +30,11 @@ class ChunkedDataExtractorContext {
     final long end;
     final TimeValue chunkSpan;
     final TimeAligner timeAligner;
+    final Map<String, String> headers;

     ChunkedDataExtractorContext(String jobId, String timeField, List<String> indices, List<String> types,
                                 QueryBuilder query, int scrollSize, long start, long end, @Nullable TimeValue chunkSpan,
-                                TimeAligner timeAligner) {
+                                TimeAligner timeAligner, Map<String, String> headers) {
         this.jobId = Objects.requireNonNull(jobId);
         this.timeField = Objects.requireNonNull(timeField);
         this.indices = indices.toArray(new String[indices.size()]);
@@ -43,5 +45,6 @@ class ChunkedDataExtractorContext {
         this.end = end;
         this.chunkSpan = chunkSpan;
         this.timeAligner = Objects.requireNonNull(timeAligner);
+        this.headers = headers;
     }
 }

View File

@@ -41,7 +41,8 @@ public class ChunkedDataExtractorFactory implements DataExtractorFactory {
                 timeAligner.alignToCeil(start),
                 timeAligner.alignToFloor(end),
                 datafeedConfig.getChunkingConfig().getTimeSpan(),
-                timeAligner);
+                timeAligner,
+                datafeedConfig.getHeaders());
         return new ChunkedDataExtractor(client, dataExtractorFactory, dataExtractorContext);
     }

View File

@@ -20,6 +20,7 @@ import org.elasticsearch.script.Script;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
 import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.xpack.ml.MlClientHelper;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractor;
 import org.elasticsearch.xpack.ml.datafeed.extractor.ExtractorUtils;
 import org.elasticsearch.xpack.ml.utils.DomainSplitFunction;
@@ -98,7 +99,7 @@ class ScrollDataExtractor implements DataExtractor {
     }

     protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
-        return searchRequestBuilder.get();
+        return MlClientHelper.execute(context.headers, client, searchRequestBuilder::get);
     }

     private SearchRequestBuilder buildSearchRequest(long start) {
@@ -182,7 +183,7 @@ class ScrollDataExtractor implements DataExtractor {
     private InputStream continueScroll() throws IOException {
         LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
-        SearchResponse searchResponse = null;
+        SearchResponse searchResponse;
         try {
             searchResponse = executeSearchScrollRequest(scrollId);
         } catch (SearchPhaseExecutionException searchExecutionException) {
@@ -208,10 +209,10 @@ class ScrollDataExtractor implements DataExtractor {
     }

     protected SearchResponse executeSearchScrollRequest(String scrollId) {
-        return SearchScrollAction.INSTANCE.newRequestBuilder(client)
-                .setScroll(SCROLL_TIMEOUT)
-                .setScrollId(scrollId)
-                .get();
+        return MlClientHelper.execute(context.headers, client, () -> SearchScrollAction.INSTANCE.newRequestBuilder(client)
+                .setScroll(SCROLL_TIMEOUT)
+                .setScrollId(scrollId)
+                .get());
     }

     private void resetScroll() {
@@ -223,7 +224,7 @@ class ScrollDataExtractor implements DataExtractor {
         if (scrollId != null) {
             ClearScrollRequest request = new ClearScrollRequest();
             request.addScrollId(scrollId);
-            client.execute(ClearScrollAction.INSTANCE, request).actionGet();
+            MlClientHelper.execute(context.headers, client, () -> client.execute(ClearScrollAction.INSTANCE, request).actionGet());
         }
     }
 }

View File

@@ -9,6 +9,7 @@ import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.search.builder.SearchSourceBuilder;

 import java.util.List;
+import java.util.Map;
 import java.util.Objects;

 class ScrollDataExtractorContext {
@@ -22,10 +23,11 @@ class ScrollDataExtractorContext {
     final int scrollSize;
     final long start;
     final long end;
+    final Map<String, String> headers;

     ScrollDataExtractorContext(String jobId, ExtractedFields extractedFields, List<String> indices, List<String> types,
                                QueryBuilder query, List<SearchSourceBuilder.ScriptField> scriptFields, int scrollSize,
-                               long start, long end) {
+                               long start, long end, Map<String, String> headers) {
         this.jobId = Objects.requireNonNull(jobId);
         this.extractedFields = Objects.requireNonNull(extractedFields);
         this.indices = indices.toArray(new String[indices.size()]);
@@ -35,5 +37,6 @@ class ScrollDataExtractorContext {
         this.scrollSize = scrollSize;
         this.start = start;
         this.end = end;
+        this.headers = headers;
     }
 }

View File

@@ -11,7 +11,6 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
 import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.IndexNotFoundException;
 import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractor;
@@ -46,7 +45,8 @@ public class ScrollDataExtractorFactory implements DataExtractorFactory {
                 datafeedConfig.getScriptFields(),
                 datafeedConfig.getScrollSize(),
                 start,
-                end);
+                end,
+                datafeedConfig.getHeaders());
         return new ScrollDataExtractor(client, dataExtractorContext);
     }

View File

@@ -0,0 +1,117 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.junit.Before;
import java.util.Collections;
import java.util.Map;
import java.util.function.Consumer;
import static org.elasticsearch.xpack.ClientHelper.ACTION_ORIGIN_TRANSIENT_NAME;
import static org.elasticsearch.xpack.ClientHelper.ML_ORIGIN;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class MlClientHelperTests extends ESTestCase {
private Client client = mock(Client.class);
@Before
public void setupMocks() {
ThreadPool threadPool = mock(ThreadPool.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(threadPool.getThreadContext()).thenReturn(threadContext);
when(client.threadPool()).thenReturn(threadPool);
PlainActionFuture<SearchResponse> searchFuture = PlainActionFuture.newFuture();
searchFuture.onResponse(new SearchResponse());
when(client.search(any())).thenReturn(searchFuture);
}
public void testEmptyHeaders() {
DatafeedConfig.Builder builder = new DatafeedConfig.Builder("datafeed-foo", "foo");
builder.setIndices(Collections.singletonList("foo-index"));
assertExecutionWithOrigin(builder.build());
}
public void testWithHeaders() {
DatafeedConfig.Builder builder = new DatafeedConfig.Builder("datafeed-foo", "foo");
builder.setIndices(Collections.singletonList("foo-index"));
Map<String, String> headers = MapBuilder.<String, String>newMapBuilder()
.put(Authentication.AUTHENTICATION_KEY, "anything")
.put(AuthenticationService.RUN_AS_USER_HEADER, "anything")
.map();
builder.setHeaders(headers);
assertRunAsExecution(builder.build(), h -> {
assertThat(h.keySet(), hasSize(2));
assertThat(h, hasEntry(Authentication.AUTHENTICATION_KEY, "anything"));
assertThat(h, hasEntry(AuthenticationService.RUN_AS_USER_HEADER, "anything"));
});
}
public void testFilteredHeaders() {
DatafeedConfig.Builder builder = new DatafeedConfig.Builder("datafeed-foo", "foo");
builder.setIndices(Collections.singletonList("foo-index"));
Map<String, String> unrelatedHeaders = MapBuilder.<String, String>newMapBuilder()
.put(randomAlphaOfLength(10), "anything")
.map();
builder.setHeaders(unrelatedHeaders);
assertRunAsExecution(builder.build(), h -> assertThat(h.keySet(), hasSize(0)));
}
/**
* This method executes a search and checks if the thread context was enriched with the ml origin
*/
private void assertExecutionWithOrigin(DatafeedConfig datafeedConfig) {
MlClientHelper.execute(datafeedConfig, client, () -> {
Object origin = client.threadPool().getThreadContext().getTransient(ACTION_ORIGIN_TRANSIENT_NAME);
assertThat(origin, is(ML_ORIGIN));
// Check that headers are not set
Map<String, String> headers = client.threadPool().getThreadContext().getHeaders();
assertThat(headers, not(hasEntry(Authentication.AUTHENTICATION_KEY, "anything")));
assertThat(headers, not(hasEntry(AuthenticationService.RUN_AS_USER_HEADER, "anything")));
return client.search(new SearchRequest()).actionGet();
});
}
/**
* This method executes a search and ensures no stashed origin thread context was created, so that the regular node
* client was used, to emulate a run_as function
*/
public void assertRunAsExecution(DatafeedConfig datafeedConfig, Consumer<Map<String, String>> consumer) {
MlClientHelper.execute(datafeedConfig, client, () -> {
Object origin = client.threadPool().getThreadContext().getTransient(ACTION_ORIGIN_TRANSIENT_NAME);
assertThat(origin, is(nullValue()));
consumer.accept(client.threadPool().getThreadContext().getHeaders());
return client.search(new SearchRequest()).actionGet();
});
}
}

View File

@@ -29,7 +29,6 @@ import org.elasticsearch.xpack.ml.job.config.JobTaskStatus;
 import org.elasticsearch.xpack.ml.job.config.JobTests;
 import org.elasticsearch.xpack.persistent.PersistentTasksCustomMetaData;

-import java.io.IOException;
 import java.util.Collections;
 import java.util.Date;
 import java.util.Map;
@@ -62,7 +61,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
             }
             job = new Job.Builder(job).setAnalysisConfig(analysisConfig).build();
             builder.putJob(job, false);
-            builder.putDatafeed(datafeedConfig);
+            builder.putDatafeed(datafeedConfig, null);
         } else {
             builder.putJob(job, false);
         }
@@ -163,7 +162,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);

         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
                 () -> builder.deleteJob(job1.getId(), new PersistentTasksCustomMetaData(0L, Collections.emptyMap())));
@@ -183,7 +182,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);

         MlMetadata result = builder.build();
         assertThat(result.getJobs().get("job_id"), sameInstance(job1));
@@ -200,7 +199,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", "missing-job").build();
         MlMetadata.Builder builder = new MlMetadata.Builder();

-        expectThrows(ResourceNotFoundException.class, () -> builder.putDatafeed(datafeedConfig1));
+        expectThrows(ResourceNotFoundException.class, () -> builder.putDatafeed(datafeedConfig1, null));
     }

     public void testPutDatafeed_failBecauseJobIsBeingDeleted() {
@@ -209,7 +208,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);

-        expectThrows(ResourceNotFoundException.class, () -> builder.putDatafeed(datafeedConfig1));
+        expectThrows(ResourceNotFoundException.class, () -> builder.putDatafeed(datafeedConfig1, null));
     }

     public void testPutDatafeed_failBecauseDatafeedIdIsAlreadyTaken() {
@@ -217,9 +216,9 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);

-        expectThrows(ResourceAlreadyExistsException.class, () -> builder.putDatafeed(datafeedConfig1));
+        expectThrows(ResourceAlreadyExistsException.class, () -> builder.putDatafeed(datafeedConfig1, null));
     }

     public void testPutDatafeed_failBecauseJobAlreadyHasDatafeed() {
@@ -228,10 +227,10 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig2 = createDatafeedConfig("datafeed2", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);

         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-                () -> builder.putDatafeed(datafeedConfig2));
+                () -> builder.putDatafeed(datafeedConfig2, null));
         assertThat(e.status(), equalTo(RestStatus.CONFLICT));
     }
@@ -245,7 +244,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1.build(now), false);

-        expectThrows(ElasticsearchStatusException.class, () -> builder.putDatafeed(datafeedConfig1));
+        expectThrows(ElasticsearchStatusException.class, () -> builder.putDatafeed(datafeedConfig1, null));
     }

     public void testUpdateDatafeed() {
@@ -253,12 +252,12 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);
         MlMetadata beforeMetadata = builder.build();

         DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeedConfig1.getId());
         update.setScrollSize(5000);
-        MlMetadata updatedMetadata = new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null).build();
+        MlMetadata updatedMetadata = new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null, null).build();

         DatafeedConfig updatedDatafeed = updatedMetadata.getDatafeed(datafeedConfig1.getId());
         assertThat(updatedDatafeed.getJobId(), equalTo(datafeedConfig1.getJobId()));
@@ -270,7 +269,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
     public void testUpdateDatafeed_failBecauseDatafeedDoesNotExist() {
         DatafeedUpdate.Builder update = new DatafeedUpdate.Builder("job_id");
         update.setScrollSize(5000);
-        expectThrows(ResourceNotFoundException.class, () -> new MlMetadata.Builder().updateDatafeed(update.build(), null).build());
+        expectThrows(ResourceNotFoundException.class, () -> new MlMetadata.Builder().updateDatafeed(update.build(), null, null).build());
     }

     public void testUpdateDatafeed_failBecauseDatafeedIsNotStopped() {
@@ -278,7 +277,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);
         MlMetadata beforeMetadata = builder.build();

         PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -290,7 +289,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         update.setScrollSize(5000);

         ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
-                () -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), tasksInProgress));
+                () -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), tasksInProgress, null));
         assertThat(e.status(), equalTo(RestStatus.CONFLICT));
     }
@@ -299,14 +298,14 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
         DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
         MlMetadata.Builder builder = new MlMetadata.Builder();
         builder.putJob(job1, false);
-        builder.putDatafeed(datafeedConfig1);
+        builder.putDatafeed(datafeedConfig1, null);
MlMetadata beforeMetadata = builder.build(); MlMetadata beforeMetadata = builder.build();
DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeedConfig1.getId()); DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeedConfig1.getId());
update.setJobId(job1.getId() + "_2"); update.setJobId(job1.getId() + "_2");
expectThrows(ResourceNotFoundException.class, expectThrows(ResourceNotFoundException.class,
() -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null)); () -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null, null));
} }
public void testUpdateDatafeed_failBecauseNewJobHasAnotherDatafeedAttached() { public void testUpdateDatafeed_failBecauseNewJobHasAnotherDatafeedAttached() {
@ -318,15 +317,15 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
MlMetadata.Builder builder = new MlMetadata.Builder(); MlMetadata.Builder builder = new MlMetadata.Builder();
builder.putJob(job1, false); builder.putJob(job1, false);
builder.putJob(job2.build(), false); builder.putJob(job2.build(), false);
builder.putDatafeed(datafeedConfig1); builder.putDatafeed(datafeedConfig1, null);
builder.putDatafeed(datafeedConfig2); builder.putDatafeed(datafeedConfig2, null);
MlMetadata beforeMetadata = builder.build(); MlMetadata beforeMetadata = builder.build();
DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeedConfig1.getId()); DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeedConfig1.getId());
update.setJobId(job2.getId()); update.setJobId(job2.getId());
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
() -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null)); () -> new MlMetadata.Builder(beforeMetadata).updateDatafeed(update.build(), null, null));
assertThat(e.status(), equalTo(RestStatus.CONFLICT)); assertThat(e.status(), equalTo(RestStatus.CONFLICT));
assertThat(e.getMessage(), equalTo("A datafeed [datafeed2] already exists for job [job_id_2]")); assertThat(e.getMessage(), equalTo("A datafeed [datafeed2] already exists for job [job_id_2]"));
} }
@ -336,7 +335,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build(); DatafeedConfig datafeedConfig1 = createDatafeedConfig("datafeed1", job1.getId()).build();
MlMetadata.Builder builder = new MlMetadata.Builder(); MlMetadata.Builder builder = new MlMetadata.Builder();
builder.putJob(job1, false); builder.putJob(job1, false);
builder.putDatafeed(datafeedConfig1); builder.putDatafeed(datafeedConfig1, null);
MlMetadata result = builder.build(); MlMetadata result = builder.build();
assertThat(result.getJobs().get("job_id"), sameInstance(job1)); assertThat(result.getJobs().get("job_id"), sameInstance(job1));
@ -377,9 +376,9 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
public void testExpandDatafeedIds() { public void testExpandDatafeedIds() {
MlMetadata.Builder mlMetadataBuilder = newMlMetadataWithJobs("bar-1", "foo-1", "foo-2"); MlMetadata.Builder mlMetadataBuilder = newMlMetadataWithJobs("bar-1", "foo-1", "foo-2");
mlMetadataBuilder.putDatafeed(createDatafeedConfig("bar-1-feed", "bar-1").build()); mlMetadataBuilder.putDatafeed(createDatafeedConfig("bar-1-feed", "bar-1").build(), null);
mlMetadataBuilder.putDatafeed(createDatafeedConfig("foo-1-feed", "foo-1").build()); mlMetadataBuilder.putDatafeed(createDatafeedConfig("foo-1-feed", "foo-1").build(), null);
mlMetadataBuilder.putDatafeed(createDatafeedConfig("foo-2-feed", "foo-2").build()); mlMetadataBuilder.putDatafeed(createDatafeedConfig("foo-2-feed", "foo-2").build(), null);
MlMetadata mlMetadata = mlMetadataBuilder.build(); MlMetadata mlMetadata = mlMetadataBuilder.build();
@ -399,7 +398,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
} }
@Override @Override
protected MlMetadata mutateInstance(MlMetadata instance) throws IOException { protected MlMetadata mutateInstance(MlMetadata instance) {
Map<String, Job> jobs = instance.getJobs(); Map<String, Job> jobs = instance.getJobs();
Map<String, DatafeedConfig> datafeeds = instance.getDatafeeds(); Map<String, DatafeedConfig> datafeeds = instance.getDatafeeds();
MlMetadata.Builder metadataBuilder = new MlMetadata.Builder(); MlMetadata.Builder metadataBuilder = new MlMetadata.Builder();
@ -408,7 +407,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
metadataBuilder.putJob(entry.getValue(), true); metadataBuilder.putJob(entry.getValue(), true);
} }
for (Map.Entry<String, DatafeedConfig> entry : datafeeds.entrySet()) { for (Map.Entry<String, DatafeedConfig> entry : datafeeds.entrySet()) {
metadataBuilder.putDatafeed(entry.getValue()); metadataBuilder.putDatafeed(entry.getValue(), null);
} }
switch (between(0, 1)) { switch (between(0, 1)) {
@ -429,7 +428,7 @@ public class MlMetadataTests extends AbstractSerializingTestCase<MlMetadata> {
} }
randomJob = new Job.Builder(randomJob).setAnalysisConfig(analysisConfig).build(); randomJob = new Job.Builder(randomJob).setAnalysisConfig(analysisConfig).build();
metadataBuilder.putJob(randomJob, false); metadataBuilder.putJob(randomJob, false);
metadataBuilder.putDatafeed(datafeedConfig); metadataBuilder.putDatafeed(datafeedConfig, null);
break; break;
default: default:
throw new AssertionError("Illegal randomisation branch"); throw new AssertionError("Illegal randomisation branch");
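The MlMetadataTests hunks above all follow from one signature change: MlMetadata.Builder#putDatafeed and #updateDatafeed now accept the security headers captured from the request that created or updated the datafeed, and these unit tests simply pass null because no security context exists there. A minimal sketch of how a caller holding real request headers might use the new two-argument form (the Map<String, String> parameter type, the register method, and the import paths are assumptions for illustration, not code from this commit):

import java.util.Map;

import org.elasticsearch.xpack.ml.MlMetadata;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.ml.job.config.Job;

class DatafeedHeadersSketch {

    // Hypothetical illustration: the second putDatafeed argument carries the security
    // headers recorded from the create/update request; tests pass null when there are none.
    static MlMetadata register(Job job, DatafeedConfig datafeed, Map<String, String> securityHeaders) {
        MlMetadata.Builder builder = new MlMetadata.Builder();
        builder.putJob(job, false);
        builder.putDatafeed(datafeed, securityHeaders);
        return builder.build();
    }
}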

View File

@@ -80,7 +80,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
MlMetadata.Builder mlBuilder = new MlMetadata.Builder();
mlBuilder.putJob(BaseMlIntegTestCase.createScheduledJob("job_id").build(new Date()), false);
mlBuilder.putDatafeed(BaseMlIntegTestCase.createDatafeed("datafeed_id", "job_id",
-Collections.singletonList("*")));
+Collections.singletonList("*")), null);
final PersistentTasksCustomMetaData.Builder startDataFeedTaskBuilder = PersistentTasksCustomMetaData.builder();
addJobTask("job_id", null, JobState.OPENED, startDataFeedTaskBuilder);
addTask("datafeed_id", 0L, null, DatafeedState.STARTED, startDataFeedTaskBuilder);
@@ -147,7 +147,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
request.setForce(true);
CloseJobAction.resolveAndValidateJobId(request, cs1, openJobs, closingJobs);
assertEquals(Arrays.asList("job_id_1", "job_id_2", "job_id_3"), openJobs);
-assertEquals(Arrays.asList("job_id_4"), closingJobs);
+assertEquals(Collections.singletonList("job_id_4"), closingJobs);
request.setForce(false);
expectThrows(ElasticsearchStatusException.class,
@@ -171,7 +171,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
CloseJobAction.Request request = new CloseJobAction.Request("job_id_1");
CloseJobAction.resolveAndValidateJobId(request, cs1, openJobs, closingJobs);
-assertEquals(Arrays.asList("job_id_1"), openJobs);
+assertEquals(Collections.singletonList("job_id_1"), openJobs);
assertEquals(Collections.emptyList(), closingJobs);
// Job without task is closed
@@ -219,7 +219,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
request.setForce(true);
CloseJobAction.resolveAndValidateJobId(request, cs1, openJobs, closingJobs);
-assertEquals(Arrays.asList("job_id_failed"), openJobs);
+assertEquals(Collections.singletonList("job_id_failed"), openJobs);
assertEquals(Collections.emptyList(), closingJobs);
openJobs.clear();
@@ -252,7 +252,7 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
CloseJobAction.resolveAndValidateJobId(new CloseJobAction.Request("_all"), cs1, openJobs, closingJobs);
assertEquals(Arrays.asList("job_id_open-1", "job_id_open-2"), openJobs);
-assertEquals(Arrays.asList("job_id_closing"), closingJobs);
+assertEquals(Collections.singletonList("job_id_closing"), closingJobs);
openJobs.clear();
closingJobs.clear();
@@ -264,12 +264,12 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
CloseJobAction.resolveAndValidateJobId(new CloseJobAction.Request("job_id_closing"), cs1, openJobs, closingJobs);
assertEquals(Collections.emptyList(), openJobs);
-assertEquals(Arrays.asList("job_id_closing"), closingJobs);
+assertEquals(Collections.singletonList("job_id_closing"), closingJobs);
openJobs.clear();
closingJobs.clear();
CloseJobAction.resolveAndValidateJobId(new CloseJobAction.Request("job_id_open-1"), cs1, openJobs, closingJobs);
-assertEquals(Arrays.asList("job_id_open-1"), openJobs);
+assertEquals(Collections.singletonList("job_id_open-1"), openJobs);
assertEquals(Collections.emptyList(), closingJobs);
openJobs.clear();
closingJobs.clear();
@@ -316,8 +316,8 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
}
public void testBuildWaitForCloseRequest() {
-List<String> openJobIds = Arrays.asList(new String[] {"openjob1", "openjob2"});
+List<String> openJobIds = Arrays.asList("openjob1", "openjob2");
-List<String> closingJobIds = Arrays.asList(new String[] {"closingjob1"});
+List<String> closingJobIds = Collections.singletonList("closingjob1");
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
addJobTask("openjob1", null, JobState.OPENED, tasksBuilder);
@@ -343,4 +343,4 @@ public class CloseJobActionRequestTests extends AbstractStreamableXContentTestCa
tasks.updateTaskStatus(MlMetadata.datafeedTaskId(datafeedId), state);
}
}

View File

@@ -43,7 +43,7 @@ public class StartDatafeedActionTests extends ESTestCase {
PersistentTasksCustomMetaData tasks = PersistentTasksCustomMetaData.builder().build();
DatafeedConfig datafeedConfig1 = DatafeedManagerTests.createDatafeedConfig("foo-datafeed", "job_id").build();
MlMetadata mlMetadata2 = new MlMetadata.Builder(mlMetadata1)
-.putDatafeed(datafeedConfig1)
+.putDatafeed(datafeedConfig1, null)
.build();
Exception e = expectThrows(ElasticsearchStatusException.class,
() -> StartDatafeedAction.validate("foo-datafeed", mlMetadata2, tasks));
@@ -60,7 +60,7 @@ public class StartDatafeedActionTests extends ESTestCase {
PersistentTasksCustomMetaData tasks = tasksBuilder.build();
DatafeedConfig datafeedConfig1 = DatafeedManagerTests.createDatafeedConfig("foo-datafeed", "job_id").build();
MlMetadata mlMetadata2 = new MlMetadata.Builder(mlMetadata1)
-.putDatafeed(datafeedConfig1)
+.putDatafeed(datafeedConfig1, null)
.build();
StartDatafeedAction.validate("foo-datafeed", mlMetadata2, tasks);
@@ -76,7 +76,7 @@ public class StartDatafeedActionTests extends ESTestCase {
PersistentTasksCustomMetaData tasks = tasksBuilder.build();
DatafeedConfig datafeedConfig1 = DatafeedManagerTests.createDatafeedConfig("foo-datafeed", "job_id").build();
MlMetadata mlMetadata2 = new MlMetadata.Builder(mlMetadata1)
-.putDatafeed(datafeedConfig1)
+.putDatafeed(datafeedConfig1, null)
.build();
StartDatafeedAction.validate("foo-datafeed", mlMetadata2, tasks);

View File

@@ -20,7 +20,6 @@ import org.elasticsearch.xpack.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.xpack.persistent.PersistentTasksCustomMetaData.Assignment;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
@@ -66,7 +65,7 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
tasksBuilder.addTask(MlMetadata.datafeedTaskId("foo"), StartDatafeedAction.TASK_NAME,
new StartDatafeedAction.DatafeedParams("foo", 0L), new Assignment("node_id", ""));
tasksBuilder.updateTaskStatus(MlMetadata.datafeedTaskId("foo"), DatafeedState.STARTED);
-PersistentTasksCustomMetaData tasks = tasksBuilder.build();
+tasksBuilder.build();
Job job = createDatafeedJob().build(new Date());
MlMetadata mlMetadata1 = new MlMetadata.Builder().putJob(job, false).build();
@@ -76,7 +75,7 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
DatafeedConfig datafeedConfig = createDatafeedConfig("foo", "job_id").build();
MlMetadata mlMetadata2 = new MlMetadata.Builder().putJob(job, false)
-.putDatafeed(datafeedConfig)
+.putDatafeed(datafeedConfig, null)
.build();
StopDatafeedAction.validateDatafeedTask("foo", mlMetadata2);
}
@@ -88,12 +87,12 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
addTask("datafeed_1", 0L, "node-1", DatafeedState.STARTED, tasksBuilder);
Job job = BaseMlIntegTestCase.createScheduledJob("job_id_1").build(new Date());
DatafeedConfig datafeedConfig = createDatafeedConfig("datafeed_1", "job_id_1").build();
-mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig);
+mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig, null);
addTask("datafeed_2", 0L, "node-1", DatafeedState.STOPPED, tasksBuilder);
job = BaseMlIntegTestCase.createScheduledJob("job_id_2").build(new Date());
datafeedConfig = createDatafeedConfig("datafeed_2", "job_id_2").build();
-mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig);
+mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig, null);
PersistentTasksCustomMetaData tasks = tasksBuilder.build();
MlMetadata mlMetadata = mlMetadataBuilder.build();
@@ -102,7 +101,7 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
List<String> stoppingDatafeeds = new ArrayList<>();
StopDatafeedAction.resolveDataFeedIds(new StopDatafeedAction.Request("datafeed_1"), mlMetadata, tasks, startedDatafeeds,
stoppingDatafeeds);
-assertEquals(Arrays.asList("datafeed_1"), startedDatafeeds);
+assertEquals(Collections.singletonList("datafeed_1"), startedDatafeeds);
assertEquals(Collections.emptyList(), stoppingDatafeeds);
startedDatafeeds.clear();
@@ -120,17 +119,17 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
addTask("datafeed_1", 0L, "node-1", DatafeedState.STARTED, tasksBuilder);
Job job = BaseMlIntegTestCase.createScheduledJob("job_id_1").build(new Date());
DatafeedConfig datafeedConfig = createDatafeedConfig("datafeed_1", "job_id_1").build();
-mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig);
+mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig, null);
addTask("datafeed_2", 0L, "node-1", DatafeedState.STOPPED, tasksBuilder);
job = BaseMlIntegTestCase.createScheduledJob("job_id_2").build(new Date());
datafeedConfig = createDatafeedConfig("datafeed_2", "job_id_2").build();
-mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig);
+mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig, null);
addTask("datafeed_3", 0L, "node-1", DatafeedState.STOPPING, tasksBuilder);
job = BaseMlIntegTestCase.createScheduledJob("job_id_3").build(new Date());
datafeedConfig = createDatafeedConfig("datafeed_3", "job_id_3").build();
-mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig);
+mlMetadataBuilder.putJob(job, false).putDatafeed(datafeedConfig, null);
PersistentTasksCustomMetaData tasks = tasksBuilder.build();
MlMetadata mlMetadata = mlMetadataBuilder.build();
@@ -139,8 +138,8 @@ public class StopDatafeedActionRequestTests extends AbstractStreamableXContentTe
List<String> stoppingDatafeeds = new ArrayList<>();
StopDatafeedAction.resolveDataFeedIds(new StopDatafeedAction.Request("_all"), mlMetadata, tasks, startedDatafeeds,
stoppingDatafeeds);
-assertEquals(Arrays.asList("datafeed_1"), startedDatafeeds);
+assertEquals(Collections.singletonList("datafeed_1"), startedDatafeeds);
-assertEquals(Arrays.asList("datafeed_3"), stoppingDatafeeds);
+assertEquals(Collections.singletonList("datafeed_3"), stoppingDatafeeds);
startedDatafeeds.clear();
stoppingDatafeeds.clear();

View File

@@ -16,8 +16,10 @@ import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
@@ -80,7 +82,7 @@ public class DatafeedManagerTests extends ESTestCase {
Job job = createDatafeedJob().build(new Date());
mlMetadata.putJob(job, false);
DatafeedConfig datafeed = createDatafeedConfig("datafeed_id", job.getId()).build();
-mlMetadata.putDatafeed(datafeed);
+mlMetadata.putDatafeed(datafeed, null);
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
addJobTask(job.getId(), "node_id", JobState.OPENED, tasksBuilder);
PersistentTasksCustomMetaData tasks = tasksBuilder.build();
@@ -109,6 +111,7 @@ public class DatafeedManagerTests extends ESTestCase {
auditor = mock(Auditor.class);
threadPool = mock(ThreadPool.class);
+when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
ExecutorService executorService = mock(ExecutorService.class);
doAnswer(invocation -> {
((Runnable) invocation.getArguments()[0]).run();
@@ -248,7 +251,7 @@ public class DatafeedManagerTests extends ESTestCase {
}
}
-public void testDatafeedTaskWaitsUntilJobIsOpened() throws Exception {
+public void testDatafeedTaskWaitsUntilJobIsOpened() {
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
addJobTask("job_id", "node_id", JobState.OPENING, tasksBuilder);
ClusterState.Builder cs = ClusterState.builder(clusterService.state())
@@ -288,7 +291,7 @@ public class DatafeedManagerTests extends ESTestCase {
verify(threadPool, times(1)).executor(MachineLearning.DATAFEED_THREAD_POOL_NAME);
}
-public void testDatafeedTaskStopsBecauseJobFailedWhileOpening() throws Exception {
+public void testDatafeedTaskStopsBecauseJobFailedWhileOpening() {
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
addJobTask("job_id", "node_id", JobState.OPENING, tasksBuilder);
ClusterState.Builder cs = ClusterState.builder(clusterService.state())
@@ -316,7 +319,7 @@ public class DatafeedManagerTests extends ESTestCase {
verify(task).stop("job_never_opened", TimeValue.timeValueSeconds(20));
}
-public void testDatafeedGetsStoppedWhileWaitingForJobToOpen() throws Exception {
+public void testDatafeedGetsStoppedWhileWaitingForJobToOpen() {
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
addJobTask("job_id", "node_id", JobState.OPENING, tasksBuilder);
ClusterState.Builder cs = ClusterState.builder(clusterService.state())
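The new stub added to DatafeedManagerTests, when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)), is needed because the datafeed's search path now goes through the thread context: a mocked ThreadPool has to hand back a real ThreadContext or the test would trip over a null context. A sketch of the same wiring, factored into a helper for reuse across tests (the helper class and method names are illustrative only, not from this commit):

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;

class MockClientWithThreadContext {

    // Build a mocked Client whose ThreadPool returns a real ThreadContext, so code that
    // stashes the context (as the datafeed search now does) can run against the mock.
    static Client create() {
        ThreadPool threadPool = mock(ThreadPool.class);
        when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
        Client client = mock(Client.class);
        when(client.threadPool()).thenReturn(threadPool);
        return client;
    }
}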

View File

@@ -63,11 +63,11 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
.build();
}
-public void testSelectNode_GivenJobIsOpened() throws Exception {
+public void testSelectNode_GivenJobIsOpened() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -81,11 +81,11 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
new DatafeedNodeSelector(clusterState, resolver, "datafeed_id").checkDatafeedTaskCanBeCreated();
}
-public void testSelectNode_GivenJobIsOpening() throws Exception {
+public void testSelectNode_GivenJobIsOpening() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -99,13 +99,13 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
new DatafeedNodeSelector(clusterState, resolver, "datafeed_id").checkDatafeedTaskCanBeCreated();
}
-public void testNoJobTask() throws Exception {
+public void testNoJobTask() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
// Using wildcard index name to test for index resolving as well
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")), null);
mlMetadata = mlMetadataBuilder.build();
tasks = PersistentTasksCustomMetaData.builder().build();
@@ -123,11 +123,11 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
+ "[cannot start datafeed [datafeed_id], because job's [job_id] state is [closed] while state [opened] is required]"));
}
-public void testSelectNode_GivenJobFailedOrClosed() throws Exception {
+public void testSelectNode_GivenJobFailedOrClosed() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -149,13 +149,13 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
+ "] while state [opened] is required]"));
}
-public void testShardUnassigned() throws Exception {
+public void testShardUnassigned() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
// Using wildcard index name to test for index resolving as well
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -175,13 +175,13 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
new DatafeedNodeSelector(clusterState, resolver, "datafeed_id").checkDatafeedTaskCanBeCreated();
}
-public void testShardNotAllActive() throws Exception {
+public void testShardNotAllActive() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
// Using wildcard index name to test for index resolving as well
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("fo*")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -202,11 +202,11 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
new DatafeedNodeSelector(clusterState, resolver, "datafeed_id").checkDatafeedTaskCanBeCreated();
}
-public void testIndexDoesntExist() throws Exception {
+public void testIndexDoesntExist() {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("not_foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("not_foo")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -230,7 +230,7 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("foo")), null);
mlMetadata = mlMetadataBuilder.build();
String nodeId = randomBoolean() ? "node_id2" : null;
@@ -261,14 +261,14 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
new DatafeedNodeSelector(clusterState, resolver, "datafeed_id").checkDatafeedTaskCanBeCreated();
}
-public void testSelectNode_GivenJobOpeningAndIndexDoesNotExist() throws Exception {
+public void testSelectNode_GivenJobOpeningAndIndexDoesNotExist() {
// Here we test that when there are 2 problems, the most critical gets reported first.
// In this case job is Opening (non-critical) and the index does not exist (critical)
MlMetadata.Builder mlMetadataBuilder = new MlMetadata.Builder();
Job job = createScheduledJob("job_id").build(new Date());
mlMetadataBuilder.putJob(job, false);
-mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("not_foo")));
+mlMetadataBuilder.putDatafeed(createDatafeed("datafeed_id", job.getId(), Collections.singletonList("not_foo")), null);
mlMetadata = mlMetadataBuilder.build();
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder();
@@ -339,4 +339,4 @@ public class DatafeedNodeSelectorTests extends ESTestCase {
return new RoutingTable.Builder().add(rtBuilder).build();
}
}

View File

@@ -25,9 +25,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder.ScriptField;
import org.elasticsearch.test.AbstractSerializingTestCase;
import org.elasticsearch.xpack.ml.datafeed.ChunkingConfig.Mode;
-import java.io.IOException;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@@ -93,7 +91,7 @@ public class DatafeedUpdateTests extends AbstractSerializingTestCase<DatafeedUpd
}
@Override
-protected DatafeedUpdate doParseInstance(XContentParser parser) throws IOException {
+protected DatafeedUpdate doParseInstance(XContentParser parser) {
return DatafeedUpdate.PARSER.apply(parser, null).build();
}
@@ -111,12 +109,12 @@ public class DatafeedUpdateTests extends AbstractSerializingTestCase<DatafeedUpd
public void testApply_failBecauseTargetDatafeedHasDifferentId() {
DatafeedConfig datafeed = DatafeedConfigTests.createRandomizedDatafeedConfig("foo");
-expectThrows(IllegalArgumentException.class, () -> createRandomized(datafeed.getId() + "_2").apply(datafeed));
+expectThrows(IllegalArgumentException.class, () -> createRandomized(datafeed.getId() + "_2").apply(datafeed, null));
}
public void testApply_givenEmptyUpdate() {
DatafeedConfig datafeed = DatafeedConfigTests.createRandomizedDatafeedConfig("foo");
-DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed);
+DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, null);
assertThat(datafeed, equalTo(updatedDatafeed));
}
@@ -127,7 +125,7 @@ public class DatafeedUpdateTests extends AbstractSerializingTestCase<DatafeedUpd
DatafeedUpdate.Builder updated = new DatafeedUpdate.Builder(datafeed.getId());
updated.setScrollSize(datafeed.getScrollSize() + 1);
-DatafeedConfig updatedDatafeed = update.build().apply(datafeed);
+DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
DatafeedConfig.Builder expectedDatafeed = new DatafeedConfig.Builder(datafeed);
expectedDatafeed.setScrollSize(datafeed.getScrollSize() + 1);
@@ -136,40 +134,40 @@ public class DatafeedUpdateTests extends AbstractSerializingTestCase<DatafeedUpd
public void testApply_givenFullUpdateNoAggregations() {
DatafeedConfig.Builder datafeedBuilder = new DatafeedConfig.Builder("foo", "foo-feed");
-datafeedBuilder.setIndices(Arrays.asList("i_1"));
+datafeedBuilder.setIndices(Collections.singletonList("i_1"));
-datafeedBuilder.setTypes(Arrays.asList("t_1"));
+datafeedBuilder.setTypes(Collections.singletonList("t_1"));
DatafeedConfig datafeed = datafeedBuilder.build();
DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeed.getId());
update.setJobId("bar");
-update.setIndices(Arrays.asList("i_2"));
+update.setIndices(Collections.singletonList("i_2"));
-update.setTypes(Arrays.asList("t_2"));
+update.setTypes(Collections.singletonList("t_2"));
update.setQueryDelay(TimeValue.timeValueSeconds(42));
update.setFrequency(TimeValue.timeValueSeconds(142));
update.setQuery(QueryBuilders.termQuery("a", "b"));
-update.setScriptFields(Arrays.asList(new SearchSourceBuilder.ScriptField("a", mockScript("b"), false)));
+update.setScriptFields(Collections.singletonList(new SearchSourceBuilder.ScriptField("a", mockScript("b"), false)));
update.setScrollSize(8000);
update.setChunkingConfig(ChunkingConfig.newManual(TimeValue.timeValueHours(1)));
-DatafeedConfig updatedDatafeed = update.build().apply(datafeed);
+DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
assertThat(updatedDatafeed.getJobId(), equalTo("bar"));
-assertThat(updatedDatafeed.getIndices(), equalTo(Arrays.asList("i_2")));
+assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_2")));
-assertThat(updatedDatafeed.getTypes(), equalTo(Arrays.asList("t_2")));
+assertThat(updatedDatafeed.getTypes(), equalTo(Collections.singletonList("t_2")));
assertThat(updatedDatafeed.getQueryDelay(), equalTo(TimeValue.timeValueSeconds(42)));
assertThat(updatedDatafeed.getFrequency(), equalTo(TimeValue.timeValueSeconds(142)));
assertThat(updatedDatafeed.getQuery(), equalTo(QueryBuilders.termQuery("a", "b")));
assertThat(updatedDatafeed.hasAggregations(), is(false));
assertThat(updatedDatafeed.getScriptFields(),
-equalTo(Arrays.asList(new SearchSourceBuilder.ScriptField("a", mockScript("b"), false))));
+equalTo(Collections.singletonList(new SearchSourceBuilder.ScriptField("a", mockScript("b"), false))));
assertThat(updatedDatafeed.getScrollSize(), equalTo(8000));
assertThat(updatedDatafeed.getChunkingConfig(), equalTo(ChunkingConfig.newManual(TimeValue.timeValueHours(1))));
}
public void testApply_givenAggregations() {
DatafeedConfig.Builder datafeedBuilder = new DatafeedConfig.Builder("foo", "foo-feed");
-datafeedBuilder.setIndices(Arrays.asList("i_1"));
+datafeedBuilder.setIndices(Collections.singletonList("i_1"));
-datafeedBuilder.setTypes(Arrays.asList("t_1"));
+datafeedBuilder.setTypes(Collections.singletonList("t_1"));
DatafeedConfig datafeed = datafeedBuilder.build();
DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(datafeed.getId());
@@ -177,17 +175,17 @@ public class DatafeedUpdateTests extends AbstractSerializingTestCase<DatafeedUpd
update.setAggregations(new AggregatorFactories.Builder().addAggregator(
AggregationBuilders.histogram("a").interval(300000).field("time").subAggregation(maxTime)));
-DatafeedConfig updatedDatafeed = update.build().apply(datafeed);
+DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
-assertThat(updatedDatafeed.getIndices(), equalTo(Arrays.asList("i_1")));
+assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_1")));
-assertThat(updatedDatafeed.getTypes(), equalTo(Arrays.asList("t_1")));
+assertThat(updatedDatafeed.getTypes(), equalTo(Collections.singletonList("t_1")));
assertThat(updatedDatafeed.getAggregations(),
equalTo(new AggregatorFactories.Builder().addAggregator(
AggregationBuilders.histogram("a").interval(300000).field("time").subAggregation(maxTime))));
}
@Override
-protected DatafeedUpdate mutateInstance(DatafeedUpdate instance) throws IOException {
+protected DatafeedUpdate mutateInstance(DatafeedUpdate instance) {
DatafeedUpdate.Builder builder = new DatafeedUpdate.Builder(instance);
switch (between(0, 10)) {
case 0:
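DatafeedUpdate.apply now takes a second argument, the headers to record on the resulting DatafeedConfig; the tests above pass null throughout. A short sketch of the updated call shape (the securityHeaders parameter name and its Map<String, String> type are assumptions drawn from these test changes):

import java.util.Map;

import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.ml.datafeed.DatafeedUpdate;

class DatafeedUpdateSketch {

    // Apply an update to an existing datafeed, carrying the security headers through
    // to the updated config (null in the unit tests, real headers in production).
    static DatafeedConfig bumpScrollSize(DatafeedConfig existing, int newScrollSize, Map<String, String> securityHeaders) {
        DatafeedUpdate.Builder update = new DatafeedUpdate.Builder(existing.getId());
        update.setScrollSize(newScrollSize);
        return update.build().apply(existing, securityHeaders);
    }
}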

View File

@@ -186,7 +186,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
assertThat(capturedSearchRequests.size(), equalTo(1));
}
-public void testExtractionGivenResponseHasMultipleTopLevelAggs() throws IOException {
+public void testExtractionGivenResponseHasMultipleTopLevelAggs() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
Histogram histogram1 = mock(Histogram.class);
@@ -203,7 +203,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
assertThat(e.getMessage(), containsString("Multiple top level aggregations not supported; found: [hist_1, hist_2]"));
}
-public void testExtractionGivenCancelBeforeNext() throws IOException {
+public void testExtractionGivenCancelBeforeNext() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 4000L);
SearchResponse response = createSearchResponse("time", Collections.emptyList());
extractor.setNextResponse(response);
@@ -249,7 +249,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
expectThrows(IOException.class, extractor::next);
}
-public void testExtractionGivenSearchResponseHasShardFailures() throws IOException {
+public void testExtractionGivenSearchResponseHasShardFailures() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithShardFailures());
@@ -257,7 +257,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
IOException e = expectThrows(IOException.class, extractor::next);
}
-public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() throws IOException {
+public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
extractor.setNextResponse(createResponseWithUnavailableShards(2));
@@ -267,7 +267,8 @@ public class AggregationDataExtractorTests extends ESTestCase {
}
private AggregationDataExtractorContext createContext(long start, long end) {
-return new AggregationDataExtractorContext(jobId, timeField, fields, indices, types, query, aggs, start, end, true);
+return new AggregationDataExtractorContext(jobId, timeField, fields, indices, types, query, aggs, start, end, true,
+        Collections.emptyMap());
}
@SuppressWarnings("unchecked")

View File

@@ -321,7 +321,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
assertThat(searchRequest, containsString("\"from\":200000,\"to\":400000"));
}
-public void testCancelGivenNextWasNeverCalled() throws IOException {
+public void testCancelGivenNextWasNeverCalled() {
chunkSpan = TimeValue.timeValueSeconds(1);
TestDataExtractor extractor = new TestDataExtractor(1000L, 2300L);
extractor.setNextResponse(createSearchResponse(10L, 1000L, 2200L));
@@ -446,7 +446,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
private ChunkedDataExtractorContext createContext(long start, long end) {
return new ChunkedDataExtractorContext(jobId, timeField, indices, types, query, scrollSize, start, end, chunkSpan,
-ChunkedDataExtractorFactory.newIdentityTimeAligner());
+ChunkedDataExtractorFactory.newIdentityTimeAligner(), Collections.emptyMap());
}
private static class StubSubExtractor implements DataExtractor {
@@ -465,7 +465,7 @@ public class ChunkedDataExtractorTests extends ESTestCase {
}
@Override
-public Optional<InputStream> next() throws IOException {
+public Optional<InputStream> next() {
if (streams.isEmpty()) {
hasNext = false;
return Optional.empty();
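The extractor context constructors above each gain a trailing headers map, and the tests pass Collections.emptyMap(). Those are the headers the extractor can hand to MlClientHelper.execute when it issues its search, so the search runs under the datafeed's recorded security headers rather than always as the internal user. A minimal sketch of that call; the executeSearch wrapper and the assumption that the context exposes the headers to the extractor are illustrative only:

import java.util.Map;

import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.xpack.ml.MlClientHelper;

class ExtractorSearchSketch {

    // Execute a datafeed search with the headers stored in the extractor context;
    // an empty map means the helper falls back to the ML origin instead.
    static SearchResponse executeSearch(Client client, SearchRequestBuilder searchRequestBuilder, Map<String, String> headers) {
        return MlClientHelper.execute(headers, client, searchRequestBuilder::get);
    }
}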

View File

@@ -15,6 +15,8 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestStatus;
@@ -24,6 +26,7 @@ import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@@ -117,7 +120,10 @@ public class ScrollDataExtractorTests extends ESTestCase {
@Before
public void setUpTests() {
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
capturedSearchRequests = new ArrayList<>();
capturedContinueScrollIds = new ArrayList<>();
jobId = "test-job";
@@ -269,7 +275,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
extractor.setNextResponse(createErrorResponse());
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testExtractionGivenContinueScrollResponseHasError() throws IOException {
@@ -288,7 +294,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
extractor.setNextResponse(createErrorResponse());
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testExtractionGivenInitSearchResponseHasShardFailures() throws IOException {
@@ -297,7 +303,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
extractor.setNextResponse(createResponseWithShardFailures());
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() throws IOException {
@@ -306,7 +312,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
extractor.setNextResponse(createResponseWithUnavailableShards(1));
assertThat(extractor.hasNext(), is(true));
IOException e = expectThrows(IOException.class, extractor::next);
assertThat(e.getMessage(), equalTo("[" + jobId + "] Search request encountered [1] unavailable shards"));
}
@@ -333,7 +339,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
assertThat(output.isPresent(), is(true));
// A second failure is not tolerated
assertThat(extractor.hasNext(), is(true));
expectThrows(IOException.class, extractor::next);
}
public void testResetScollUsesLastResultTimestamp() throws IOException {
@@ -389,7 +395,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
assertEquals(new Long(1400L), extractor.getLastTimestamp());
// A second failure is not tolerated
assertThat(extractor.hasNext(), is(true));
expectThrows(SearchPhaseExecutionException.class, extractor::next);
}
public void testSearchPhaseExecutionExceptionOnInitScroll() throws IOException {
@@ -398,7 +404,7 @@ public class ScrollDataExtractorTests extends ESTestCase {
extractor.setNextResponse(createResponseWithShardFailures());
extractor.setNextResponse(createResponseWithShardFailures());
expectThrows(IOException.class, extractor::next);
List<String> capturedClearScrollIds = getCapturedClearScrollIds();
assertThat(capturedClearScrollIds.isEmpty(), is(true));
@@ -412,8 +418,8 @@ public class ScrollDataExtractorTests extends ESTestCase {
"script2", new Script(ScriptType.INLINE, "painless", "return domainSplit('foo.com', params);", emptyMap()), false);
List<SearchSourceBuilder.ScriptField> sFields = Arrays.asList(withoutSplit, withSplit);
ScrollDataExtractorContext context = new ScrollDataExtractorContext(jobId, extractedFields, indices,
types, query, sFields, scrollSize, 1000, 2000, Collections.emptyMap());
TestDataExtractor extractor = new TestDataExtractor(context);
@@ -460,7 +466,8 @@ public class ScrollDataExtractorTests extends ESTestCase {
}
private ScrollDataExtractorContext createContext(long start, long end) {
return new ScrollDataExtractorContext(jobId, extractedFields, indices, types, query, scriptFields, scrollSize, start, end,
Collections.emptyMap());
}
private SearchResponse createEmptySearchResponse() {
@@ -475,9 +482,9 @@ public class ScrollDataExtractorTests extends ESTestCase {
for (int i = 0; i < timestamps.size(); i++) {
SearchHit hit = new SearchHit(randomInt());
Map<String, DocumentField> fields = new HashMap<>();
fields.put(extractedFields.timeField(), new DocumentField("time", Collections.singletonList(timestamps.get(i))));
fields.put("field_1", new DocumentField("field_1", Collections.singletonList(field1Values.get(i))));
fields.put("field_2", new DocumentField("field_2", Collections.singletonList(field2Values.get(i))));
hit.fields(fields);
hits.add(hit);
}
@@ -519,4 +526,4 @@ public class ScrollDataExtractorTests extends ESTestCase {
return reader.lines().collect(Collectors.joining("\n"));
}
}
}
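Note that in these tests the new trailing constructor argument is always Collections.emptyMap(). As a rough sketch only (variable names assumed, not taken from this diff), non-test code would pass the headers stored on the datafeed config rather than an empty map, so that the extractor's searches carry the creating user's credentials:

    // Sketch, assuming a DatafeedConfig instance is in scope and its recorded headers are wanted:
    Map<String, String> headers = datafeedConfig.getHeaders();
    ScrollDataExtractorContext context = new ScrollDataExtractorContext(jobId, extractedFields, indices,
            types, query, scriptFields, scrollSize, start, end, headers);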

View File

@@ -16,6 +16,7 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.SecuritySettingsSource;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.notifications.Auditor;
import org.elasticsearch.xpack.test.rest.XPackRestTestHelper;
import org.junit.After;
import org.junit.Before;
@@ -24,8 +25,10 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
@@ -39,6 +42,8 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING);
private static final String BASIC_AUTH_VALUE_ML_ADMIN =
basicAuthHeaderValue("ml_admin", SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING);
private static final String BASIC_AUTH_VALUE_ML_ADMIN_WITH_SOME_DATA_ACCESS =
basicAuthHeaderValue("ml_admin_plus_data", SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING);
@Override
protected Settings restClientSettings() {
@@ -50,25 +55,39 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
return true;
}
private void setupDataAccessRole(String index) throws IOException {
String json = "{"
+ " \"indices\" : ["
+ " { \"names\": [\"" + index + "\"], \"privileges\": [\"read\"] }"
+ " ]"
+ "}";
client().performRequest("put", "_xpack/security/role/test_data_access", Collections.emptyMap(),
new StringEntity(json, ContentType.APPLICATION_JSON));
}
private void setupUser(String user, List<String> roles) throws IOException {
String password = new String(SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING.getChars());
String json = "{"
+ " \"password\" : \"" + password + "\","
+ " \"roles\" : [ " + roles.stream().map(unquoted -> "\"" + unquoted + "\"").collect(Collectors.joining(", ")) + " ]"
+ "}";
client().performRequest("put", "_xpack/security/user/" + user, Collections.emptyMap(),
new StringEntity(json, ContentType.APPLICATION_JSON));
}
@Before
public void setUpData() throws Exception {
setupDataAccessRole("network-data");
// This user has admin rights on machine learning, but (importantly for the tests) no rights
// on any of the data indexes
setupUser("ml_admin", Collections.singletonList("machine_learning_admin"));
// This user has admin rights on machine learning, and read access to the network-data index
setupUser("ml_admin_plus_data", Arrays.asList("machine_learning_admin", "test_data_access"));
addAirlineData();
addNetworkData("network-data");
}
private void addAirlineData() throws IOException {
@@ -221,7 +240,7 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
client().performRequest("post", "_refresh");
}
private void addNetworkData(String index) throws IOException {
// Create index with source = enabled, doc_values = enabled, stored = false + multi-field
String mappings = "{"
@@ -241,19 +260,19 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
+ " }"
+ " }"
+ "}";
client().performRequest("put", index, Collections.emptyMap(), new StringEntity(mappings, ContentType.APPLICATION_JSON));
String docTemplate = "{\"timestamp\":%d,\"host\":\"%s\",\"network_bytes_out\":%d}";
Date date = new Date(1464739200735L);
for (int i=0; i<120; i++) {
long byteCount = randomNonNegativeLong();
String jsonDoc = String.format(Locale.ROOT, docTemplate, date.getTime(), "hostA", byteCount);
client().performRequest("post", index + "/doc", Collections.emptyMap(),
new StringEntity(jsonDoc, ContentType.APPLICATION_JSON));
byteCount = randomNonNegativeLong();
jsonDoc = String.format(Locale.ROOT, docTemplate, date.getTime(), "hostB", byteCount);
client().performRequest("post", index + "/doc", Collections.emptyMap(),
new StringEntity(jsonDoc, ContentType.APPLICATION_JSON));
date = new Date(date.getTime() + 10_000);
@@ -263,7 +282,6 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
client().performRequest("post", "_refresh");
}
public void testLookbackOnlyWithMixedTypes() throws Exception {
new LookbackOnlyTestHelper("test-lookback-only-with-mixed-types", "airline-data")
.setShouldSucceedProcessing(true).execute();
@@ -494,6 +512,52 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
assertThat(jobStatsResponseAsString, containsString("\"processed_record_count\":240"));
}
public void testLookbackWithoutPermissions() throws Exception {
String jobId = "permission-test-network-job";
String job = "{\"analysis_config\" :{\"bucket_span\":\"300s\","
+ "\"summary_count_field_name\":\"doc_count\","
+ "\"detectors\":[{\"function\":\"mean\",\"field_name\":\"bytes-delta\",\"by_field_name\":\"hostname\"}]},"
+ "\"data_description\" : {\"time_field\":\"timestamp\"}"
+ "}";
client().performRequest("put", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId, Collections.emptyMap(),
new StringEntity(job, ContentType.APPLICATION_JSON));
String datafeedId = "datafeed-" + jobId;
String aggregations =
"{\"hostname\": {\"terms\" : {\"field\": \"host.keyword\", \"size\":10},"
+ "\"aggs\": {\"buckets\": {\"date_histogram\":{\"field\":\"timestamp\",\"interval\":\"5s\"},"
+ "\"aggs\": {\"timestamp\":{\"max\":{\"field\":\"timestamp\"}},"
+ "\"bytes-delta\":{\"derivative\":{\"buckets_path\":\"avg_bytes_out\"}},"
+ "\"avg_bytes_out\":{\"avg\":{\"field\":\"network_bytes_out\"}} }}}}}";
// At the time we create the datafeed the user can access the network-data index that we have access to
new DatafeedBuilder(datafeedId, jobId, "network-data", "doc")
.setAggregations(aggregations)
.setChunkingTimespan("300s")
.setAuthHeader(BASIC_AUTH_VALUE_ML_ADMIN_WITH_SOME_DATA_ACCESS)
.build();
// Change the role so that the user can no longer access network-data
setupDataAccessRole("some-other-data");
openJob(client(), jobId);
startDatafeedAndWaitUntilStopped(datafeedId, BASIC_AUTH_VALUE_ML_ADMIN_WITH_SOME_DATA_ACCESS);
waitUntilJobIsClosed(jobId);
Response jobStatsResponse = client().performRequest("get", MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats");
String jobStatsResponseAsString = responseEntityToString(jobStatsResponse);
// We expect that no data made it through to the job
assertThat(jobStatsResponseAsString, containsString("\"input_record_count\":0"));
assertThat(jobStatsResponseAsString, containsString("\"processed_record_count\":0"));
// There should be a notification saying that there was a problem extracting data
client().performRequest("post", "_refresh");
Response notificationsResponse = client().performRequest("get", Auditor.NOTIFICATIONS_INDEX + "/_search?q=job_id:" + jobId);
String notificationsResponseAsString = responseEntityToString(notificationsResponse);
assertThat(notificationsResponseAsString, containsString("\"message\":\"Datafeed is encountering errors extracting data: " +
"action [indices:data/read/search] is unauthorized for user [ml_admin_plus_data]\""));
}
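The test above hinges on the datafeed reusing the credentials supplied when it was created: once those credentials lose read access to network-data, the lookback extracts nothing and audits the authorization error. As a hedged illustration only (reusing the helpers from this test class; org.hamcrest.Matchers.not is an assumed extra import; this mirror case is not part of the diff), leaving the role untouched would let the same lookback feed data into the job:

    // Sketch of the mirror case: the creator's read access to "network-data" stays in place,
    // so the lookback runs with those credentials and the job's record counts end up non-zero.
    setupDataAccessRole("network-data");
    openJob(client(), jobId);
    startDatafeedAndWaitUntilStopped(datafeedId, BASIC_AUTH_VALUE_ML_ADMIN_WITH_SOME_DATA_ACCESS);
    waitUntilJobIsClosed(jobId);
    Response jobStats = client().performRequest("get",
            MachineLearning.BASE_PATH + "anomaly_detectors/" + jobId + "/_stats");
    // Exact counts depend on the aggregation, so a real test would assert concrete numbers here.
    assertThat(responseEntityToString(jobStats), not(containsString("\"processed_record_count\":0")));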
public void testLookbackWithPipelineBucketAgg() throws Exception {
String jobId = "pipeline-bucket-agg-job";
String job = "{\"analysis_config\" :{\"bucket_span\":\"1h\","
@@ -665,10 +729,14 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
assertThat(jobStatsResponseAsString, containsString("\"missing_field_count\":0"));
}
}
private void startDatafeedAndWaitUntilStopped(String datafeedId) throws Exception {
startDatafeedAndWaitUntilStopped(datafeedId, BASIC_AUTH_VALUE_SUPER_USER);
}
private void startDatafeedAndWaitUntilStopped(String datafeedId, String authHeader) throws Exception {
Response startDatafeedRequest = client().performRequest("post",
MachineLearning.BASE_PATH + "datafeeds/" + datafeedId + "/_start?start=2016-06-01T00:00:00Z&end=2016-06-02T00:00:00Z",
new BasicHeader("Authorization", authHeader));
assertThat(startDatafeedRequest.getStatusLine().getStatusCode(), equalTo(200));
assertThat(responseEntityToString(startDatafeedRequest), equalTo("{\"started\":true}"));
assertBusy(() -> {
@@ -763,9 +831,9 @@ public class DatafeedJobsRestIT extends ESRestTestCase {
}
DatafeedBuilder setChunkingTimespan(String timespan) {
chunkingTimespan = timespan;
return this;
}
Response build() throws IOException {
String datafeedConfig = "{"

View File

@@ -7,7 +7,7 @@ minimal:
# Give all users involved in these tests access to the indices where the data to
# be analyzed is stored, because the ML roles alone do not provide access to
# non-ML indices
- names: [ 'airline-data', 'index-*', 'unavailable-data', 'utopia' ]
privileges:
- indices:admin/create
- indices:admin/refresh