From 88f44a9f66b61fb158b3f4605a57909c86037489 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 13 Jun 2018 15:42:18 +0100 Subject: [PATCH] [ML] Check licence when datafeeds use cross cluster search (#31247) This change prevents a datafeed using cross cluster search from starting if the remote cluster does not have x-pack installed and a sufficient license. The check is made only when starting a datafeed. --- .../core/ml/datafeed/DatafeedConfigTests.java | 25 ++- .../action/TransportStartDatafeedAction.java | 129 +++++++---- .../ml/datafeed/DatafeedNodeSelector.java | 6 +- .../ml/datafeed/MlRemoteLicenseChecker.java | 192 +++++++++++++++++ .../process/autodetect/AutodetectProcess.java | 2 +- .../datafeed/MlRemoteLicenseCheckerTests.java | 200 ++++++++++++++++++ 6 files changed, 494 insertions(+), 60 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java index 6aa987fc0e9..d59ef16dfdf 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java @@ -40,7 +40,6 @@ import org.joda.time.DateTimeZone; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.TimeZone; @@ -193,11 +192,11 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase conf.setIndices(null)); } - public void testCheckValid_GivenEmptyIndices() throws IOException { + public void testCheckValid_GivenEmptyIndices() { DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1"); conf.setIndices(Collections.emptyList()); ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, conf::build); assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "indices", "[]"), e.getMessage()); } - public void testCheckValid_GivenIndicesContainsOnlyNulls() throws IOException { + public void testCheckValid_GivenIndicesContainsOnlyNulls() { List indices = new ArrayList<>(); indices.add(null); indices.add(null); @@ -230,7 +229,7 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase indices = new ArrayList<>(); indices.add(""); indices.add(""); @@ -240,27 +239,27 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase conf.setQueryDelay(TimeValue.timeValueMillis(-10))); assertEquals("query_delay cannot be less than 0. Value = -10", e.getMessage()); } - public void testCheckValid_GivenZeroFrequency() throws IOException { + public void testCheckValid_GivenZeroFrequency() { DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1"); IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class, () -> conf.setFrequency(TimeValue.ZERO)); assertEquals("frequency cannot be less or equal than 0. Value = 0s", e.getMessage()); } - public void testCheckValid_GivenNegativeFrequency() throws IOException { + public void testCheckValid_GivenNegativeFrequency() { DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1"); IllegalArgumentException e = ESTestCase.expectThrows(IllegalArgumentException.class, () -> conf.setFrequency(TimeValue.timeValueMinutes(-1))); assertEquals("frequency cannot be less or equal than 0. Value = -1", e.getMessage()); } - public void testCheckValid_GivenNegativeScrollSize() throws IOException { + public void testCheckValid_GivenNegativeScrollSize() { DatafeedConfig.Builder conf = new DatafeedConfig.Builder("datafeed1", "job1"); ElasticsearchException e = ESTestCase.expectThrows(ElasticsearchException.class, () -> conf.setScrollSize(-1000)); assertEquals(Messages.getMessage(Messages.DATAFEED_CONFIG_INVALID_OPTION_VALUE, "scroll_size", -1000L), e.getMessage()); @@ -414,7 +413,7 @@ public class DatafeedConfigTests extends AbstractSerializingTestCase listener) { StartDatafeedAction.DatafeedParams params = request.getParams(); if (licenseState.isMachineLearningAllowed()) { - ActionListener> finalListener = - new ActionListener>() { - @Override - public void onResponse(PersistentTasksCustomMetaData.PersistentTask persistentTask) { - waitForDatafeedStarted(persistentTask.getId(), params, listener); - } - @Override - public void onFailure(Exception e) { - if (e instanceof ResourceAlreadyExistsException) { - logger.debug("datafeed already started", e); - e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() + - "] because it has already been started", RestStatus.CONFLICT); - } - listener.onFailure(e); - } - }; + ActionListener> waitForTaskListener = + new ActionListener>() { + @Override + public void onResponse(PersistentTasksCustomMetaData.PersistentTask + persistentTask) { + waitForDatafeedStarted(persistentTask.getId(), params, listener); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof ResourceAlreadyExistsException) { + logger.debug("datafeed already started", e); + e = new ElasticsearchStatusException("cannot start datafeed [" + params.getDatafeedId() + + "] because it has already been started", RestStatus.CONFLICT); + } + listener.onFailure(e); + } + }; // Verify data extractor factory can be created, then start persistent task MlMetadata mlMetadata = MlMetadata.getMlMetadata(state); @@ -135,16 +139,39 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction - persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()), - StartDatafeedAction.TASK_NAME, params, finalListener) - , listener::onFailure)); + + if (MlRemoteLicenseChecker.containsRemoteIndex(datafeed.getIndices())) { + MlRemoteLicenseChecker remoteLicenseChecker = new MlRemoteLicenseChecker(client); + remoteLicenseChecker.checkRemoteClusterLicenses(MlRemoteLicenseChecker.remoteClusterNames(datafeed.getIndices()), + ActionListener.wrap( + response -> { + if (response.isViolated()) { + listener.onFailure(createUnlicensedError(datafeed.getId(), response)); + } else { + createDataExtractor(job, datafeed, params, waitForTaskListener); + } + }, + e -> listener.onFailure(createUnknownLicenseError(datafeed.getId(), + MlRemoteLicenseChecker.remoteIndices(datafeed.getIndices()), e)) + )); + } else { + createDataExtractor(job, datafeed, params, waitForTaskListener); + } } else { listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); } } + private void createDataExtractor(Job job, DatafeedConfig datafeed, StartDatafeedAction.DatafeedParams params, + ActionListener> + listener) { + DataExtractorFactory.create(client, datafeed, job, ActionListener.wrap( + dataExtractorFactory -> + persistentTasksService.sendStartRequest(MLMetadataField.datafeedTaskId(params.getDatafeedId()), + StartDatafeedAction.TASK_NAME, params, listener) + , listener::onFailure)); + } + @Override protected ClusterBlockException checkBlock(StartDatafeedAction.Request request, ClusterState state) { // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, @@ -158,28 +185,29 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction() { - @Override - public void onResponse(PersistentTasksCustomMetaData.PersistentTask persistentTask) { - if (predicate.exception != null) { - // We want to return to the caller without leaving an unassigned persistent task, to match - // what would have happened if the error had been detected in the "fast fail" validation - cancelDatafeedStart(persistentTask, predicate.exception, listener); - } else { - listener.onResponse(new StartDatafeedAction.Response(true)); - } - } + @Override + public void onResponse(PersistentTasksCustomMetaData.PersistentTask + persistentTask) { + if (predicate.exception != null) { + // We want to return to the caller without leaving an unassigned persistent task, to match + // what would have happened if the error had been detected in the "fast fail" validation + cancelDatafeedStart(persistentTask, predicate.exception, listener); + } else { + listener.onResponse(new StartDatafeedAction.Response(true)); + } + } - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } - @Override - public void onTimeout(TimeValue timeout) { - listener.onFailure(new ElasticsearchException("Starting datafeed [" - + params.getDatafeedId() + "] timed out after [" + timeout + "]")); - } - }); + @Override + public void onTimeout(TimeValue timeout) { + listener.onFailure(new ElasticsearchException("Starting datafeed [" + + params.getDatafeedId() + "] timed out after [" + timeout + "]")); + } + }); } private void cancelDatafeedStart(PersistentTasksCustomMetaData.PersistentTask persistentTask, @@ -203,6 +231,25 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction remoteIndices, + Exception cause) { + String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use" + + " indices on a remote cluster " + remoteIndices + + " but the license type could not be verified"; + + return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, new Exception(cause.getMessage())); + } + public static class StartDatafeedPersistentTasksExecutor extends PersistentTasksExecutor { private final DatafeedManager datafeedManager; private final IndexNameExpressionResolver resolver; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java index 37f9715d094..0eb57ab79be 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java @@ -91,7 +91,7 @@ public class DatafeedNodeSelector { List indices = datafeed.getIndices(); for (String index : indices) { - if (isRemoteIndex(index)) { + if (MlRemoteLicenseChecker.isRemoteIndex(index)) { // We cannot verify remote indices continue; } @@ -122,10 +122,6 @@ public class DatafeedNodeSelector { return null; } - private boolean isRemoteIndex(String index) { - return index.indexOf(':') != -1; - } - private static class AssignmentFailure { private final String reason; private final boolean isCriticalForTaskCreation; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java new file mode 100644 index 00000000000..b55713f6d0a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java @@ -0,0 +1,192 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.datafeed; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.license.License; +import org.elasticsearch.license.XPackInfoResponse; +import org.elasticsearch.transport.ActionNotFoundTransportException; +import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.xpack.core.action.XPackInfoAction; +import org.elasticsearch.xpack.core.action.XPackInfoRequest; + +import java.util.EnumSet; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +/** + * ML datafeeds can use cross cluster search to access data in a remote cluster. + * The remote cluster should be licenced for ML this class performs that check + * using the _xpack (info) endpoint. + */ +public class MlRemoteLicenseChecker { + + private final Client client; + + public static class RemoteClusterLicenseInfo { + private final String clusterName; + private final XPackInfoResponse.LicenseInfo licenseInfo; + + RemoteClusterLicenseInfo(String clusterName, XPackInfoResponse.LicenseInfo licenseInfo) { + this.clusterName = clusterName; + this.licenseInfo = licenseInfo; + } + + public String getClusterName() { + return clusterName; + } + + public XPackInfoResponse.LicenseInfo getLicenseInfo() { + return licenseInfo; + } + } + + public class LicenseViolation { + private final RemoteClusterLicenseInfo licenseInfo; + + private LicenseViolation(@Nullable RemoteClusterLicenseInfo licenseInfo) { + this.licenseInfo = licenseInfo; + } + + public boolean isViolated() { + return licenseInfo != null; + } + + public RemoteClusterLicenseInfo get() { + return licenseInfo; + } + } + + public MlRemoteLicenseChecker(Client client) { + this.client = client; + } + + /** + * Check each cluster is licensed for ML. + * This function evaluates lazily and will terminate when the first cluster + * that is not licensed is found or an error occurs. + * + * @param clusterNames List of remote cluster names + * @param listener Response listener + */ + public void checkRemoteClusterLicenses(List clusterNames, ActionListener listener) { + final Iterator itr = clusterNames.iterator(); + if (itr.hasNext() == false) { + listener.onResponse(new LicenseViolation(null)); + return; + } + + final AtomicReference clusterName = new AtomicReference<>(itr.next()); + + ActionListener infoListener = new ActionListener() { + @Override + public void onResponse(XPackInfoResponse xPackInfoResponse) { + if (licenseSupportsML(xPackInfoResponse.getLicenseInfo()) == false) { + listener.onResponse(new LicenseViolation( + new RemoteClusterLicenseInfo(clusterName.get(), xPackInfoResponse.getLicenseInfo()))); + return; + } + + if (itr.hasNext()) { + clusterName.set(itr.next()); + remoteClusterLicense(clusterName.get(), this); + } else { + listener.onResponse(new LicenseViolation(null)); + } + } + + @Override + public void onFailure(Exception e) { + String message = "Could not determine the X-Pack licence type for cluster [" + clusterName.get() + "]"; + if (e instanceof ActionNotFoundTransportException) { + // This is likely to be because x-pack is not installed in the target cluster + message += ". Is X-Pack installed on the target cluster?"; + } + listener.onFailure(new ElasticsearchException(message, e)); + } + }; + + remoteClusterLicense(clusterName.get(), infoListener); + } + + private void remoteClusterLicense(String clusterName, ActionListener listener) { + Client remoteClusterClient = client.getRemoteClusterClient(clusterName); + ThreadContext threadContext = remoteClusterClient.threadPool().getThreadContext(); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + // we stash any context here since this is an internal execution and should not leak any + // existing context information. + threadContext.markAsSystemContext(); + + XPackInfoRequest request = new XPackInfoRequest(); + request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE)); + remoteClusterClient.execute(XPackInfoAction.INSTANCE, request, listener); + } + } + + static boolean licenseSupportsML(XPackInfoResponse.LicenseInfo licenseInfo) { + License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode()); + return licenseInfo.getStatus() == License.Status.ACTIVE && + (mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL); + } + + public static boolean isRemoteIndex(String index) { + return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1; + } + + public static boolean containsRemoteIndex(List indices) { + return indices.stream().anyMatch(MlRemoteLicenseChecker::isRemoteIndex); + } + + /** + * Get any remote indices used in cross cluster search. + * Remote indices are of the form {@code cluster_name:index_name} + * @return List of remote cluster indices + */ + public static List remoteIndices(List indices) { + return indices.stream().filter(MlRemoteLicenseChecker::isRemoteIndex).collect(Collectors.toList()); + } + + /** + * Extract the list of remote cluster names from the list of indices. + * @param indices List of indices. Remote cluster indices are prefixed + * with {@code cluster-name:} + * @return Every cluster name found in {@code indices} + */ + public static List remoteClusterNames(List indices) { + return indices.stream() + .filter(MlRemoteLicenseChecker::isRemoteIndex) + .map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR))) + .distinct() + .collect(Collectors.toList()); + } + + public static String buildErrorMessage(RemoteClusterLicenseInfo clusterLicenseInfo) { + StringBuilder error = new StringBuilder(); + if (clusterLicenseInfo.licenseInfo.getStatus() != License.Status.ACTIVE) { + error.append("The license on cluster [").append(clusterLicenseInfo.clusterName) + .append("] is not active. "); + } else { + License.OperationMode mode = License.OperationMode.resolve(clusterLicenseInfo.licenseInfo.getMode()); + if (mode != License.OperationMode.PLATINUM && mode != License.OperationMode.TRIAL) { + error.append("The license mode [").append(mode) + .append("] on cluster [") + .append(clusterLicenseInfo.clusterName) + .append("] does not enable Machine Learning. "); + } + } + + error.append(Strings.toString(clusterLicenseInfo.licenseInfo)); + return error.toString(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java index 049880b1ac2..21be815d561 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcess.java @@ -117,7 +117,7 @@ public interface AutodetectProcess extends Closeable { /** * Ask the job to start persisting model state in the background - * @throws IOException + * @throws IOException If writing the request fails */ void persistJob() throws IOException; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java new file mode 100644 index 00000000000..47d4d30a7c6 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.ml.datafeed; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.license.License; +import org.elasticsearch.license.XPackInfoResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.action.XPackInfoAction; +import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class MlRemoteLicenseCheckerTests extends ESTestCase { + + public void testIsRemoteIndex() { + List indices = Arrays.asList("local-index1", "local-index2"); + assertFalse(MlRemoteLicenseChecker.containsRemoteIndex(indices)); + indices = Arrays.asList("local-index1", "remote-cluster:remote-index2"); + assertTrue(MlRemoteLicenseChecker.containsRemoteIndex(indices)); + } + + public void testRemoteIndices() { + List indices = Collections.singletonList("local-index"); + assertThat(MlRemoteLicenseChecker.remoteIndices(indices), is(empty())); + indices = Arrays.asList("local-index", "remote-cluster:index1", "local-index2", "remote-cluster2:index1"); + assertThat(MlRemoteLicenseChecker.remoteIndices(indices), containsInAnyOrder("remote-cluster:index1", "remote-cluster2:index1")); + } + + public void testRemoteClusterNames() { + List indices = Arrays.asList("local-index1", "local-index2"); + assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), empty()); + indices = Arrays.asList("local-index1", "remote-cluster1:remote-index2"); + assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1")); + indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1"); + assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2")); + indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1", "remote-cluster2:index2"); + assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2")); + } + + public void testLicenseSupportsML() { + XPackInfoResponse.LicenseInfo licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", + License.Status.ACTIVE, randomNonNegativeLong()); + assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); + + licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", License.Status.EXPIRED, randomNonNegativeLong()); + assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); + + licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "GOLD", "GOLD", License.Status.ACTIVE, randomNonNegativeLong()); + assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); + + licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.ACTIVE, randomNonNegativeLong()); + assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); + } + + public void testCheckRemoteClusterLicenses_givenValidLicenses() { + final AtomicInteger index = new AtomicInteger(0); + final List responses = new ArrayList<>(); + + Client client = createMockClient(); + doAnswer(invocationMock -> { + @SuppressWarnings("raw_types") + ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + + List remoteClusterNames = Arrays.asList("valid1", "valid2", "valid3"); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client); + AtomicReference licCheckResponse = new AtomicReference<>(); + + licenseChecker.checkRemoteClusterLicenses(remoteClusterNames, + new ActionListener() { + @Override + public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) { + licCheckResponse.set(response); + } + + @Override + public void onFailure(Exception e) { + fail(e.getMessage()); + } + }); + + verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any()); + assertNotNull(licCheckResponse.get()); + assertFalse(licCheckResponse.get().isViolated()); + assertNull(licCheckResponse.get().get()); + } + + public void testCheckRemoteClusterLicenses_givenInvalidLicense() { + final AtomicInteger index = new AtomicInteger(0); + List remoteClusterNames = Arrays.asList("good", "cluster-with-basic-license", "good2"); + final List responses = new ArrayList<>(); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + Client client = createMockClient(); + doAnswer(invocationMock -> { + @SuppressWarnings("raw_types") + ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client); + AtomicReference licCheckResponse = new AtomicReference<>(); + + licenseChecker.checkRemoteClusterLicenses(remoteClusterNames, + new ActionListener() { + @Override + public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) { + licCheckResponse.set(response); + } + + @Override + public void onFailure(Exception e) { + fail(e.getMessage()); + } + }); + + verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any()); + assertNotNull(licCheckResponse.get()); + assertTrue(licCheckResponse.get().isViolated()); + assertEquals("cluster-with-basic-license", licCheckResponse.get().get().getClusterName()); + assertEquals("BASIC", licCheckResponse.get().get().getLicenseInfo().getType()); + } + + public void testBuildErrorMessage() { + XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse(); + MlRemoteLicenseChecker.RemoteClusterLicenseInfo info = + new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence); + assertEquals(Strings.toString(platinumLicence), MlRemoteLicenseChecker.buildErrorMessage(info)); + + XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse(); + info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense); + String expected = "The license mode [BASIC] on cluster [basic-cluster] does not enable Machine Learning. " + + Strings.toString(basicLicense); + assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info)); + + XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse(); + info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense); + expected = "The license on cluster [expired-cluster] is not active. " + Strings.toString(expiredLicense); + assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info)); + } + + private Client createMockClient() { + Client client = mock(Client.class); + ThreadPool threadPool = mock(ThreadPool.class); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + when(client.getRemoteClusterClient(anyString())).thenReturn(client); + return client; + } + + private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.ACTIVE, randomNonNegativeLong()); + } + + private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", License.Status.ACTIVE, randomNonNegativeLong()); + } + + private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", License.Status.EXPIRED, randomNonNegativeLong()); + } +}