[ML] Check licence when datafeeds use cross cluster search (#31247)
This change prevents a datafeed using cross cluster search from starting if the remote cluster does not have x-pack installed and a sufficient license. The check is made only when starting a datafeed.
This commit is contained in:
parent
7199d5f0e6
commit
88f44a9f66
|
@ -40,7 +40,6 @@ import org.joda.time.DateTimeZone;
|
|||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.TimeZone;
|
||||
|
@ -193,11 +192,11 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase<DatafeedCon
|
|||
|
||||
public void testDefaultQueryDelay() {
|
||||
DatafeedConfig.Builder feedBuilder1 = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
feedBuilder1.setIndices(Arrays.asList("foo"));
|
||||
feedBuilder1.setIndices(Collections.singletonList("foo"));
|
||||
DatafeedConfig.Builder feedBuilder2 = new DatafeedConfig.Builder("datafeed2", "job1");
|
||||
feedBuilder2.setIndices(Arrays.asList("foo"));
|
||||
feedBuilder2.setIndices(Collections.singletonList("foo"));
|
||||
DatafeedConfig.Builder feedBuilder3 = new DatafeedConfig.Builder("datafeed3", "job2");
|
||||
feedBuilder3.setIndices(Arrays.asList("foo"));
|
||||
feedBuilder3.setIndices(Collections.singletonList("foo"));
|
||||
DatafeedConfig feed1 = feedBuilder1.build();
|
||||
DatafeedConfig feed2 = feedBuilder2.build();
|
||||
DatafeedConfig feed3 = feedBuilder3.build();
|
||||
|
@ -208,19 +207,19 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase<DatafeedCon
|
|||
assertThat(feed1.getQueryDelay(), not(equalTo(feed3.getQueryDelay())));
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenNullIndices() throws IOException {
|
||||
public void testCheckValid_GivenNullIndices() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
expectThrows(IllegalArgumentException.class, () -> conf.setIndices(null));
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenEmptyIndices() throws IOException {
|
||||
public void testCheckValid_GivenEmptyIndices() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
conf.setIndices(Collections.emptyList());
|
||||
ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, conf::build);
|
||||
assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[]"), e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenIndicesContainsOnlyNulls() throws IOException {
|
||||
public void testCheckValid_GivenIndicesContainsOnlyNulls() {
|
||||
List<String> indices = new ArrayList<>();
|
||||
indices.add(null);
|
||||
indices.add(null);
|
||||
|
@ -230,7 +229,7 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase<DatafeedCon
|
|||
assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[null, null]"), e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenIndicesContainsOnlyEmptyStrings() throws IOException {
|
||||
public void testCheckValid_GivenIndicesContainsOnlyEmptyStrings() {
|
||||
List<String> indices = new ArrayList<>();
|
||||
indices.add("");
|
||||
indices.add("");
|
||||
|
@ -240,27 +239,27 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase<DatafeedCon
|
|||
assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[, ]"), e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenNegativeQueryDelay() throws IOException {
|
||||
public void testCheckValid_GivenNegativeQueryDelay() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class,
|
||||
() -> conf.setQueryDelay(TimeValue.timeValueMillis(-10)));
|
||||
assertEquals("query_delay cannot be less than 0. Value = -10", e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenZeroFrequency() throws IOException {
|
||||
public void testCheckValid_GivenZeroFrequency() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class, () -> conf.setFrequency(TimeValue.ZERO));
|
||||
assertEquals("frequency cannot be less or equal than 0. Value = 0s", e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenNegativeFrequency() throws IOException {
|
||||
public void testCheckValid_GivenNegativeFrequency() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class,
|
||||
() -> conf.setFrequency(TimeValue.timeValueMinutes(-1)));
|
||||
assertEquals("frequency cannot be less or equal than 0. Value = -1", e.getMessage());
|
||||
}
|
||||
|
||||
public void testCheckValid_GivenNegativeScrollSize() throws IOException {
|
||||
public void testCheckValid_GivenNegativeScrollSize() {
|
||||
DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1");
|
||||
ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, () -> conf.setScrollSize(-1000));
|
||||
assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "scroll_size", -1000L), e.getMessage());
|
||||
|
@ -414,7 +413,7 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase<DatafeedCon
|
|||
|
||||
public void testDefaultFrequency_GivenNoAggregations() {
|
||||
DatafeedConfig.Builder datafeedBuilder = new DatafeedConfig.Builder("feed", "job");
|
||||
datafeedBuilder.setIndices(Arrays.asList("my_index"));
|
||||
datafeedBuilder.setIndices(Collections.singletonList("my_index"));
|
||||
DatafeedConfig datafeed = datafeedBuilder.build();
|
||||
|
||||
assertEquals(TimeValue.timeValueMinutes(1), datafeed.defaultFrequency(TimeValue.timeValueSeconds(1)));
|
||||
|
|
|
@ -43,10 +43,12 @@ import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
|
|||
import org.elasticsearch.persistent.PersistentTasksExecutor;
|
||||
import org.elasticsearch.persistent.PersistentTasksService;
|
||||
import org.elasticsearch.xpack.ml.MachineLearning;
|
||||
import org.elasticsearch.xpack.ml.datafeed.MlRemoteLicenseChecker;
|
||||
import org.elasticsearch.xpack.ml.datafeed.DatafeedManager;
|
||||
import org.elasticsearch.xpack.ml.datafeed.DatafeedNodeSelector;
|
||||
import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
|
@ -111,23 +113,25 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
|
|||
ActionListener<StartDatafeedAction.Response> listener) {
|
||||
StartDatafeedAction.DatafeedParams params = request.getParams();
|
||||
if (licenseState.isMachineLearningAllowed()) {
|
||||
ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>> finalListener =
|
||||
new ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>>() {
|
||||
@Override
|
||||
public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask) {
|
||||
waitForDatafeedStarted(persistentTask.getId(), params, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
if (e instanceof ResourceAlreadyExistsException) {
|
||||
logger.debug("datafeed already started", e);
|
||||
e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() +
|
||||
"] because it has already been started", RestStatus.CONFLICT);
|
||||
}
|
||||
listener.onFailure(e);
|
||||
}
|
||||
};
|
||||
ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>> waitForTaskListener =
|
||||
new ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>>() {
|
||||
@Override
|
||||
public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>
|
||||
persistentTask) {
|
||||
waitForDatafeedStarted(persistentTask.getId(), params, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
if (e instanceof ResourceAlreadyExistsException) {
|
||||
logger.debug("datafeed already started", e);
|
||||
e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() +
|
||||
"] because it has already been started", RestStatus.CONFLICT);
|
||||
}
|
||||
listener.onFailure(e);
|
||||
}
|
||||
};
|
||||
|
||||
// Verify data extractor factory can be created, then start persistent task
|
||||
MlMetadata mlMetadata = MlMetadata.getMlMetadata(state);
|
||||
|
@ -135,16 +139,39 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
|
|||
validate(params.getDatafeedId(), mlMetadata, tasks);
|
||||
DatafeedConfig datafeed = mlMetadata.getDatafeed(params.getDatafeedId());
|
||||
Job job = mlMetadata.getJobs().get(datafeed.getJobId());
|
||||
DataExtractorFactory.create(client, datafeed, job, ActionListener.wrap(
|
||||
dataExtractorFactory ->
|
||||
persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()),
|
||||
StartDatafeedAction.TASK_NAME, params, finalListener)
|
||||
, listener::onFailure));
|
||||
|
||||
if (MlRemoteLicenseChecker.containsRemoteIndex(datafeed.getIndices())) {
|
||||
MlRemoteLicenseChecker remoteLicenseChecker = new MlRemoteLicenseChecker(client);
|
||||
remoteLicenseChecker.checkRemoteClusterLicenses(MlRemoteLicenseChecker.remoteClusterNames(datafeed.getIndices()),
|
||||
ActionListener.wrap(
|
||||
response -> {
|
||||
if (response.isViolated()) {
|
||||
listener.onFailure(createUnlicensedError(datafeed.getId(), response));
|
||||
} else {
|
||||
createDataExtractor(job, datafeed, params, waitForTaskListener);
|
||||
}
|
||||
},
|
||||
e -> listener.onFailure(createUnknownLicenseError(datafeed.getId(),
|
||||
MlRemoteLicenseChecker.remoteIndices(datafeed.getIndices()), e))
|
||||
));
|
||||
} else {
|
||||
createDataExtractor(job, datafeed, params, waitForTaskListener);
|
||||
}
|
||||
} else {
|
||||
listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING));
|
||||
}
|
||||
}
|
||||
|
||||
private void createDataExtractor(Job job, DatafeedConfig datafeed, StartDatafeedAction.DatafeedParams params,
|
||||
ActionListener<PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>>
|
||||
listener) {
|
||||
DataExtractorFactory.create(client, datafeed, job, ActionListener.wrap(
|
||||
dataExtractorFactory ->
|
||||
persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()),
|
||||
StartDatafeedAction.TASK_NAME, params, listener)
|
||||
, listener::onFailure));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected ClusterBlockException checkBlock(StartDatafeedAction.Request request, ClusterState state) {
|
||||
// We only delegate here to PersistentTasksService, but if there is a metadata writeblock,
|
||||
|
@ -158,28 +185,29 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
|
|||
DatafeedPredicate predicate = new DatafeedPredicate();
|
||||
persistentTasksService.waitForPersistentTaskCondition(taskId, predicate, params.getTimeout(),
|
||||
new PersistentTasksService.WaitForPersistentTaskListener<StartDatafeedAction.DatafeedParams>() {
|
||||
@Override
|
||||
public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask) {
|
||||
if (predicate.exception != null) {
|
||||
// We want to return to the caller without leaving an unassigned persistent task, to match
|
||||
// what would have happened if the error had been detected in the "fast fail" validation
|
||||
cancelDatafeedStart(persistentTask, predicate.exception, listener);
|
||||
} else {
|
||||
listener.onResponse(new StartDatafeedAction.Response(true));
|
||||
}
|
||||
}
|
||||
@Override
|
||||
public void onResponse(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams>
|
||||
persistentTask) {
|
||||
if (predicate.exception != null) {
|
||||
// We want to return to the caller without leaving an unassigned persistent task, to match
|
||||
// what would have happened if the error had been detected in the "fast fail" validation
|
||||
cancelDatafeedStart(persistentTask, predicate.exception, listener);
|
||||
} else {
|
||||
listener.onResponse(new StartDatafeedAction.Response(true));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
listener.onFailure(e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onTimeout(TimeValue timeout) {
|
||||
listener.onFailure(new ElasticsearchException("Starting datafeed ["
|
||||
+ params.getDatafeedId() + "] timed out after [" + timeout + "]"));
|
||||
}
|
||||
});
|
||||
@Override
|
||||
public void onTimeout(TimeValue timeout) {
|
||||
listener.onFailure(new ElasticsearchException("Starting datafeed ["
|
||||
+ params.getDatafeedId() + "] timed out after [" + timeout + "]"));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void cancelDatafeedStart(PersistentTasksCustomMetaData.PersistentTask<StartDatafeedAction.DatafeedParams> persistentTask,
|
||||
|
@ -203,6 +231,25 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction<Star
|
|||
);
|
||||
}
|
||||
|
||||
private ElasticsearchStatusException createUnlicensedError(String datafeedId,
|
||||
MlRemoteLicenseChecker.LicenseViolation licenseViolation) {
|
||||
String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use "
|
||||
+ "indices on a remote cluster [" + licenseViolation.get().getClusterName()
|
||||
+ "] that is not licensed for Machine Learning. "
|
||||
+ MlRemoteLicenseChecker.buildErrorMessage(licenseViolation.get());
|
||||
|
||||
return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST);
|
||||
}
|
||||
|
||||
private ElasticsearchStatusException createUnknownLicenseError(String datafeedId, List<String> remoteIndices,
|
||||
Exception cause) {
|
||||
String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use"
|
||||
+ " indices on a remote cluster " + remoteIndices
|
||||
+ " but the license type could not be verified";
|
||||
|
||||
return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, new Exception(cause.getMessage()));
|
||||
}
|
||||
|
||||
public static class StartDatafeedPersistentTasksExecutor extends PersistentTasksExecutor<StartDatafeedAction.DatafeedParams> {
|
||||
private final DatafeedManager datafeedManager;
|
||||
private final IndexNameExpressionResolver resolver;
|
||||
|
|
|
@ -91,7 +91,7 @@ public class DatafeedNodeSelector {
|
|||
List<String> indices = datafeed.getIndices();
|
||||
for (String index : indices) {
|
||||
|
||||
if (isRemoteIndex(index)) {
|
||||
if (MlRemoteLicenseChecker.isRemoteIndex(index)) {
|
||||
// We cannot verify remote indices
|
||||
continue;
|
||||
}
|
||||
|
@ -122,10 +122,6 @@ public class DatafeedNodeSelector {
|
|||
return null;
|
||||
}
|
||||
|
||||
private boolean isRemoteIndex(String index) {
|
||||
return index.indexOf(':') != -1;
|
||||
}
|
||||
|
||||
private static class AssignmentFailure {
|
||||
private final String reason;
|
||||
private final boolean isCriticalForTaskCreation;
|
||||
|
|
|
@ -0,0 +1,192 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.datafeed;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.Nullable;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.license.XPackInfoResponse;
|
||||
import org.elasticsearch.transport.ActionNotFoundTransportException;
|
||||
import org.elasticsearch.transport.RemoteClusterAware;
|
||||
import org.elasticsearch.xpack.core.action.XPackInfoAction;
|
||||
import org.elasticsearch.xpack.core.action.XPackInfoRequest;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* ML datafeeds can use cross cluster search to access data in a remote cluster.
|
||||
* The remote cluster should be licenced for ML this class performs that check
|
||||
* using the _xpack (info) endpoint.
|
||||
*/
|
||||
public class MlRemoteLicenseChecker {
|
||||
|
||||
private final Client client;
|
||||
|
||||
public static class RemoteClusterLicenseInfo {
|
||||
private final String clusterName;
|
||||
private final XPackInfoResponse.LicenseInfo licenseInfo;
|
||||
|
||||
RemoteClusterLicenseInfo(String clusterName, XPackInfoResponse.LicenseInfo licenseInfo) {
|
||||
this.clusterName = clusterName;
|
||||
this.licenseInfo = licenseInfo;
|
||||
}
|
||||
|
||||
public String getClusterName() {
|
||||
return clusterName;
|
||||
}
|
||||
|
||||
public XPackInfoResponse.LicenseInfo getLicenseInfo() {
|
||||
return licenseInfo;
|
||||
}
|
||||
}
|
||||
|
||||
public class LicenseViolation {
|
||||
private final RemoteClusterLicenseInfo licenseInfo;
|
||||
|
||||
private LicenseViolation(@Nullable RemoteClusterLicenseInfo licenseInfo) {
|
||||
this.licenseInfo = licenseInfo;
|
||||
}
|
||||
|
||||
public boolean isViolated() {
|
||||
return licenseInfo != null;
|
||||
}
|
||||
|
||||
public RemoteClusterLicenseInfo get() {
|
||||
return licenseInfo;
|
||||
}
|
||||
}
|
||||
|
||||
public MlRemoteLicenseChecker(Client client) {
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check each cluster is licensed for ML.
|
||||
* This function evaluates lazily and will terminate when the first cluster
|
||||
* that is not licensed is found or an error occurs.
|
||||
*
|
||||
* @param clusterNames List of remote cluster names
|
||||
* @param listener Response listener
|
||||
*/
|
||||
public void checkRemoteClusterLicenses(List<String> clusterNames, ActionListener<LicenseViolation> listener) {
|
||||
final Iterator<String> itr = clusterNames.iterator();
|
||||
if (itr.hasNext() == false) {
|
||||
listener.onResponse(new LicenseViolation(null));
|
||||
return;
|
||||
}
|
||||
|
||||
final AtomicReference<String> clusterName = new AtomicReference<>(itr.next());
|
||||
|
||||
ActionListener<XPackInfoResponse> infoListener = new ActionListener<XPackInfoResponse>() {
|
||||
@Override
|
||||
public void onResponse(XPackInfoResponse xPackInfoResponse) {
|
||||
if (licenseSupportsML(xPackInfoResponse.getLicenseInfo()) == false) {
|
||||
listener.onResponse(new LicenseViolation(
|
||||
new RemoteClusterLicenseInfo(clusterName.get(), xPackInfoResponse.getLicenseInfo())));
|
||||
return;
|
||||
}
|
||||
|
||||
if (itr.hasNext()) {
|
||||
clusterName.set(itr.next());
|
||||
remoteClusterLicense(clusterName.get(), this);
|
||||
} else {
|
||||
listener.onResponse(new LicenseViolation(null));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
String message = "Could not determine the X-Pack licence type for cluster [" + clusterName.get() + "]";
|
||||
if (e instanceof ActionNotFoundTransportException) {
|
||||
// This is likely to be because x-pack is not installed in the target cluster
|
||||
message += ". Is X-Pack installed on the target cluster?";
|
||||
}
|
||||
listener.onFailure(new ElasticsearchException(message, e));
|
||||
}
|
||||
};
|
||||
|
||||
remoteClusterLicense(clusterName.get(), infoListener);
|
||||
}
|
||||
|
||||
private void remoteClusterLicense(String clusterName, ActionListener<XPackInfoResponse> listener) {
|
||||
Client remoteClusterClient = client.getRemoteClusterClient(clusterName);
|
||||
ThreadContext threadContext = remoteClusterClient.threadPool().getThreadContext();
|
||||
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
|
||||
// we stash any context here since this is an internal execution and should not leak any
|
||||
// existing context information.
|
||||
threadContext.markAsSystemContext();
|
||||
|
||||
XPackInfoRequest request = new XPackInfoRequest();
|
||||
request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE));
|
||||
remoteClusterClient.execute(XPackInfoAction.INSTANCE, request, listener);
|
||||
}
|
||||
}
|
||||
|
||||
static boolean licenseSupportsML(XPackInfoResponse.LicenseInfo licenseInfo) {
|
||||
License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode());
|
||||
return licenseInfo.getStatus() == License.Status.ACTIVE &&
|
||||
(mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL);
|
||||
}
|
||||
|
||||
public static boolean isRemoteIndex(String index) {
|
||||
return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1;
|
||||
}
|
||||
|
||||
public static boolean containsRemoteIndex(List<String> indices) {
|
||||
return indices.stream().anyMatch(MlRemoteLicenseChecker::isRemoteIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any remote indices used in cross cluster search.
|
||||
* Remote indices are of the form {@code cluster_name:index_name}
|
||||
* @return List of remote cluster indices
|
||||
*/
|
||||
public static List<String> remoteIndices(List<String> indices) {
|
||||
return indices.stream().filter(MlRemoteLicenseChecker::isRemoteIndex).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the list of remote cluster names from the list of indices.
|
||||
* @param indices List of indices. Remote cluster indices are prefixed
|
||||
* with {@code cluster-name:}
|
||||
* @return Every cluster name found in {@code indices}
|
||||
*/
|
||||
public static List<String> remoteClusterNames(List<String> indices) {
|
||||
return indices.stream()
|
||||
.filter(MlRemoteLicenseChecker::isRemoteIndex)
|
||||
.map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR)))
|
||||
.distinct()
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static String buildErrorMessage(RemoteClusterLicenseInfo clusterLicenseInfo) {
|
||||
StringBuilder error = new StringBuilder();
|
||||
if (clusterLicenseInfo.licenseInfo.getStatus() != License.Status.ACTIVE) {
|
||||
error.append("The license on cluster [").append(clusterLicenseInfo.clusterName)
|
||||
.append("] is not active. ");
|
||||
} else {
|
||||
License.OperationMode mode = License.OperationMode.resolve(clusterLicenseInfo.licenseInfo.getMode());
|
||||
if (mode != License.OperationMode.PLATINUM && mode != License.OperationMode.TRIAL) {
|
||||
error.append("The license mode [").append(mode)
|
||||
.append("] on cluster [")
|
||||
.append(clusterLicenseInfo.clusterName)
|
||||
.append("] does not enable Machine Learning. ");
|
||||
}
|
||||
}
|
||||
|
||||
error.append(Strings.toString(clusterLicenseInfo.licenseInfo));
|
||||
return error.toString();
|
||||
}
|
||||
}
|
|
@ -117,7 +117,7 @@ public interface AutodetectProcess extends Closeable {
|
|||
|
||||
/**
|
||||
* Ask the job to start persisting model state in the background
|
||||
* @throws IOException
|
||||
* @throws IOException If writing the request fails
|
||||
*/
|
||||
void persistJob() throws IOException;
|
||||
|
||||
|
|
|
@ -0,0 +1,200 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License;
|
||||
* you may not use this file except in compliance with the Elastic License.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.ml.datafeed;
|
||||
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.client.Client;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.common.util.concurrent.ThreadContext;
|
||||
import org.elasticsearch.license.License;
|
||||
import org.elasticsearch.license.XPackInfoResponse;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.action.XPackInfoAction;
|
||||
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.hamcrest.Matchers.contains;
|
||||
import static org.hamcrest.Matchers.containsInAnyOrder;
|
||||
import static org.hamcrest.Matchers.empty;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.Matchers.any;
|
||||
import static org.mockito.Matchers.anyString;
|
||||
import static org.mockito.Matchers.same;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.times;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
public class MlRemoteLicenseCheckerTests extends ESTestCase {
|
||||
|
||||
public void testIsRemoteIndex() {
|
||||
List<String> indices = Arrays.asList("local-index1", "local-index2");
|
||||
assertFalse(MlRemoteLicenseChecker.containsRemoteIndex(indices));
|
||||
indices = Arrays.asList("local-index1", "remote-cluster:remote-index2");
|
||||
assertTrue(MlRemoteLicenseChecker.containsRemoteIndex(indices));
|
||||
}
|
||||
|
||||
public void testRemoteIndices() {
|
||||
List<String> indices = Collections.singletonList("local-index");
|
||||
assertThat(MlRemoteLicenseChecker.remoteIndices(indices), is(empty()));
|
||||
indices = Arrays.asList("local-index", "remote-cluster:index1", "local-index2", "remote-cluster2:index1");
|
||||
assertThat(MlRemoteLicenseChecker.remoteIndices(indices), containsInAnyOrder("remote-cluster:index1", "remote-cluster2:index1"));
|
||||
}
|
||||
|
||||
public void testRemoteClusterNames() {
|
||||
List<String> indices = Arrays.asList("local-index1", "local-index2");
|
||||
assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), empty());
|
||||
indices = Arrays.asList("local-index1", "remote-cluster1:remote-index2");
|
||||
assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1"));
|
||||
indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1");
|
||||
assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2"));
|
||||
indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1", "remote-cluster2:index2");
|
||||
assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2"));
|
||||
}
|
||||
|
||||
public void testLicenseSupportsML() {
|
||||
XPackInfoResponse.LicenseInfo licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial",
|
||||
License.Status.ACTIVE, randomNonNegativeLong());
|
||||
assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
|
||||
|
||||
licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", License.Status.EXPIRED, randomNonNegativeLong());
|
||||
assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
|
||||
|
||||
licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "GOLD", "GOLD", License.Status.ACTIVE, randomNonNegativeLong());
|
||||
assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
|
||||
|
||||
licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.ACTIVE, randomNonNegativeLong());
|
||||
assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo));
|
||||
}
|
||||
|
||||
public void testCheckRemoteClusterLicenses_givenValidLicenses() {
|
||||
final AtomicInteger index = new AtomicInteger(0);
|
||||
final List<XPackInfoResponse> responses = new ArrayList<>();
|
||||
|
||||
Client client = createMockClient();
|
||||
doAnswer(invocationMock -> {
|
||||
@SuppressWarnings("raw_types")
|
||||
ActionListener listener = (ActionListener) invocationMock.getArguments()[2];
|
||||
listener.onResponse(responses.get(index.getAndIncrement()));
|
||||
return null;
|
||||
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
|
||||
|
||||
|
||||
List<String> remoteClusterNames = Arrays.asList("valid1", "valid2", "valid3");
|
||||
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
|
||||
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
|
||||
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
|
||||
|
||||
MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client);
|
||||
AtomicReference<MlRemoteLicenseChecker.LicenseViolation> licCheckResponse = new AtomicReference<>();
|
||||
|
||||
licenseChecker.checkRemoteClusterLicenses(remoteClusterNames,
|
||||
new ActionListener<MlRemoteLicenseChecker.LicenseViolation>() {
|
||||
@Override
|
||||
public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) {
|
||||
licCheckResponse.set(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
fail(e.getMessage());
|
||||
}
|
||||
});
|
||||
|
||||
verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any());
|
||||
assertNotNull(licCheckResponse.get());
|
||||
assertFalse(licCheckResponse.get().isViolated());
|
||||
assertNull(licCheckResponse.get().get());
|
||||
}
|
||||
|
||||
public void testCheckRemoteClusterLicenses_givenInvalidLicense() {
|
||||
final AtomicInteger index = new AtomicInteger(0);
|
||||
List<String> remoteClusterNames = Arrays.asList("good", "cluster-with-basic-license", "good2");
|
||||
final List<XPackInfoResponse> responses = new ArrayList<>();
|
||||
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
|
||||
responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null));
|
||||
responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null));
|
||||
|
||||
Client client = createMockClient();
|
||||
doAnswer(invocationMock -> {
|
||||
@SuppressWarnings("raw_types")
|
||||
ActionListener listener = (ActionListener) invocationMock.getArguments()[2];
|
||||
listener.onResponse(responses.get(index.getAndIncrement()));
|
||||
return null;
|
||||
}).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any());
|
||||
|
||||
MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client);
|
||||
AtomicReference<MlRemoteLicenseChecker.LicenseViolation> licCheckResponse = new AtomicReference<>();
|
||||
|
||||
licenseChecker.checkRemoteClusterLicenses(remoteClusterNames,
|
||||
new ActionListener<MlRemoteLicenseChecker.LicenseViolation>() {
|
||||
@Override
|
||||
public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) {
|
||||
licCheckResponse.set(response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(Exception e) {
|
||||
fail(e.getMessage());
|
||||
}
|
||||
});
|
||||
|
||||
verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any());
|
||||
assertNotNull(licCheckResponse.get());
|
||||
assertTrue(licCheckResponse.get().isViolated());
|
||||
assertEquals("cluster-with-basic-license", licCheckResponse.get().get().getClusterName());
|
||||
assertEquals("BASIC", licCheckResponse.get().get().getLicenseInfo().getType());
|
||||
}
|
||||
|
||||
public void testBuildErrorMessage() {
|
||||
XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse();
|
||||
MlRemoteLicenseChecker.RemoteClusterLicenseInfo info =
|
||||
new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence);
|
||||
assertEquals(Strings.toString(platinumLicence), MlRemoteLicenseChecker.buildErrorMessage(info));
|
||||
|
||||
XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse();
|
||||
info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense);
|
||||
String expected = "The license mode [BASIC] on cluster [basic-cluster] does not enable Machine Learning. "
|
||||
+ Strings.toString(basicLicense);
|
||||
assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info));
|
||||
|
||||
XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse();
|
||||
info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense);
|
||||
expected = "The license on cluster [expired-cluster] is not active. " + Strings.toString(expiredLicense);
|
||||
assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info));
|
||||
}
|
||||
|
||||
private Client createMockClient() {
|
||||
Client client = mock(Client.class);
|
||||
ThreadPool threadPool = mock(ThreadPool.class);
|
||||
when(client.threadPool()).thenReturn(threadPool);
|
||||
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
|
||||
when(client.getRemoteClusterClient(anyString())).thenReturn(client);
|
||||
return client;
|
||||
}
|
||||
|
||||
private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() {
|
||||
return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.ACTIVE, randomNonNegativeLong());
|
||||
}
|
||||
|
||||
private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() {
|
||||
return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", License.Status.ACTIVE, randomNonNegativeLong());
|
||||
}
|
||||
|
||||
private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() {
|
||||
return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.EXPIRED, randomNonNegativeLong());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue