From 9050c7e846b6a34e7a33e05418b29f5789154943 Mon Sep 17 00:00:00 2001 From: Jason Tedor Date: Mon, 20 Aug 2018 15:33:29 -0400 Subject: [PATCH] Generalize remote license checker (#32971) Machine learning has baked a remote license checker for use in checking license compatibility of a remote license. This remote license checker has general usage for any feature that relies on a remote cluster. For example, cross-cluster replication will pull changes from a remote cluster and require that the local and remote clusters have platinum licenses. This commit generalizes the remote cluster license check for use in cross-cluster replication. --- .../license/RemoteClusterLicenseChecker.java | 281 ++++++++++++ .../RemoteClusterLicenseCheckerTests.java | 414 ++++++++++++++++++ .../action/TransportStartDatafeedAction.java | 56 ++- .../ml/datafeed/DatafeedNodeSelector.java | 3 +- .../ml/datafeed/MlRemoteLicenseChecker.java | 193 -------- .../TransportStartDatafeedActionTests.java | 3 +- .../datafeed/MlRemoteLicenseCheckerTests.java | 199 --------- 7 files changed, 735 insertions(+), 414 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java new file mode 100644 index 00000000000..043224e357b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/license/RemoteClusterLicenseChecker.java @@ -0,0 +1,281 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.license; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.protocol.xpack.XPackInfoRequest; +import org.elasticsearch.protocol.xpack.XPackInfoResponse; +import org.elasticsearch.protocol.xpack.license.LicenseStatus; +import org.elasticsearch.transport.RemoteClusterAware; +import org.elasticsearch.xpack.core.action.XPackInfoAction; + +import java.util.EnumSet; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** + * Checks remote clusters for license compatibility with a specified license predicate. + */ +public final class RemoteClusterLicenseChecker { + + /** + * Encapsulates the license info of a remote cluster. + */ + public static final class RemoteClusterLicenseInfo { + + private final String clusterAlias; + + /** + * The alias of the remote cluster. + * + * @return the cluster alias + */ + public String clusterAlias() { + return clusterAlias; + } + + private final XPackInfoResponse.LicenseInfo licenseInfo; + + /** + * The license info of the remote cluster. + * + * @return the license info + */ + public XPackInfoResponse.LicenseInfo licenseInfo() { + return licenseInfo; + } + + RemoteClusterLicenseInfo(final String clusterAlias, final XPackInfoResponse.LicenseInfo licenseInfo) { + this.clusterAlias = clusterAlias; + this.licenseInfo = licenseInfo; + } + + } + + /** + * Encapsulates a remote cluster license check. The check is either successful if the license of the remote cluster is compatible with + * the predicate used to check license compatibility, or the check is a failure. + */ + public static final class LicenseCheck { + + private final RemoteClusterLicenseInfo remoteClusterLicenseInfo; + + /** + * The remote cluster license info. This method should only be invoked if this instance represents a failing license check. + * + * @return the remote cluster license info + */ + public RemoteClusterLicenseInfo remoteClusterLicenseInfo() { + assert isSuccess() == false; + return remoteClusterLicenseInfo; + } + + private static final LicenseCheck SUCCESS = new LicenseCheck(null); + + /** + * A successful license check. + * + * @return a successful license check instance + */ + public static LicenseCheck success() { + return SUCCESS; + } + + /** + * Test if this instance represents a successful license check. + * + * @return true if this instance represents a successful license check, otherwise false + */ + public boolean isSuccess() { + return this == SUCCESS; + } + + /** + * Creates a failing license check encapsulating the specified remote cluster license info. + * + * @param remoteClusterLicenseInfo the remote cluster license info + * @return a failing license check + */ + public static LicenseCheck failure(final RemoteClusterLicenseInfo remoteClusterLicenseInfo) { + return new LicenseCheck(remoteClusterLicenseInfo); + } + + private LicenseCheck(final RemoteClusterLicenseInfo remoteClusterLicenseInfo) { + this.remoteClusterLicenseInfo = remoteClusterLicenseInfo; + } + + } + + private final Client client; + private final Predicate predicate; + + /** + * Constructs a remote cluster license checker with the specified license predicate for checking license compatibility. The predicate + * does not need to check for the active license state as this is handled by the remote cluster license checker. + * + * @param client the client + * @param predicate the license predicate + */ + public RemoteClusterLicenseChecker(final Client client, final Predicate predicate) { + this.client = client; + this.predicate = predicate; + } + + public static boolean isLicensePlatinumOrTrial(final XPackInfoResponse.LicenseInfo licenseInfo) { + final License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode()); + return mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL; + } + + /** + * Checks the specified clusters for license compatibility. The specified callback will be invoked once if all clusters are + * license-compatible, otherwise the specified callback will be invoked once on the first cluster that is not license-compatible. + * + * @param clusterAliases the cluster aliases to check + * @param listener a callback + */ + public void checkRemoteClusterLicenses(final List clusterAliases, final ActionListener listener) { + final Iterator clusterAliasesIterator = clusterAliases.iterator(); + if (clusterAliasesIterator.hasNext() == false) { + listener.onResponse(LicenseCheck.success()); + return; + } + + final AtomicReference clusterAlias = new AtomicReference<>(); + + final ActionListener infoListener = new ActionListener() { + + @Override + public void onResponse(final XPackInfoResponse xPackInfoResponse) { + final XPackInfoResponse.LicenseInfo licenseInfo = xPackInfoResponse.getLicenseInfo(); + if ((licenseInfo.getStatus() == LicenseStatus.ACTIVE) == false || predicate.test(licenseInfo) == false) { + listener.onResponse(LicenseCheck.failure(new RemoteClusterLicenseInfo(clusterAlias.get(), licenseInfo))); + return; + } + + if (clusterAliasesIterator.hasNext()) { + clusterAlias.set(clusterAliasesIterator.next()); + // recurse to the next cluster + remoteClusterLicense(clusterAlias.get(), this); + } else { + listener.onResponse(LicenseCheck.success()); + } + } + + @Override + public void onFailure(final Exception e) { + final String message = "could not determine the license type for cluster [" + clusterAlias.get() + "]"; + listener.onFailure(new ElasticsearchException(message, e)); + } + + }; + + // check the license on the first cluster, and then we recursively check licenses on the remaining clusters + clusterAlias.set(clusterAliasesIterator.next()); + remoteClusterLicense(clusterAlias.get(), infoListener); + } + + private void remoteClusterLicense(final String clusterAlias, final ActionListener listener) { + final ThreadContext threadContext = client.threadPool().getThreadContext(); + final ContextPreservingActionListener contextPreservingActionListener = + new ContextPreservingActionListener<>(threadContext.newRestorableContext(false), listener); + try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { + // we stash any context here since this is an internal execution and should not leak any existing context information + threadContext.markAsSystemContext(); + + final XPackInfoRequest request = new XPackInfoRequest(); + request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE)); + try { + client.getRemoteClusterClient(clusterAlias).execute(XPackInfoAction.INSTANCE, request, contextPreservingActionListener); + } catch (final Exception e) { + contextPreservingActionListener.onFailure(e); + } + } + } + + /** + * Predicate to test if the index name represents the name of a remote index. + * + * @param index the index name + * @return true if the collection of indices contains a remote index, otherwise false + */ + public static boolean isRemoteIndex(final String index) { + return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1; + } + + /** + * Predicate to test if the collection of index names contains any that represent the name of a remote index. + * + * @param indices the collection of index names + * @return true if the collection of index names contains a name that represents a remote index, otherwise false + */ + public static boolean containsRemoteIndex(final List indices) { + return indices.stream().anyMatch(RemoteClusterLicenseChecker::isRemoteIndex); + } + + /** + * Filters the collection of index names for names that represent a remote index. Remote index names are of the form + * {@code cluster_name:index_name}. + * + * @param indices the collection of index names + * @return list of index names that represent remote index names + */ + public static List remoteIndices(final List indices) { + return indices.stream().filter(RemoteClusterLicenseChecker::isRemoteIndex).collect(Collectors.toList()); + } + + /** + * Extract the list of remote cluster aliases from the list of index names. Remote index names are of the form + * {@code cluster_alias:index_name} and the cluster_alias is extracted for each index name that represents a remote index. + * + * @param indices the collection of index names + * @return the remote cluster names + */ + public static List remoteClusterAliases(final List indices) { + return indices.stream() + .filter(RemoteClusterLicenseChecker::isRemoteIndex) + .map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR))) + .distinct() + .collect(Collectors.toList()); + } + + /** + * Constructs an error message for license incompatibility. + * + * @param feature the name of the feature that initiated the remote cluster license check. + * @param remoteClusterLicenseInfo the remote cluster license info of the cluster that failed the license check + * @return an error message representing license incompatibility + */ + public static String buildErrorMessage( + final String feature, + final RemoteClusterLicenseInfo remoteClusterLicenseInfo, + final Predicate predicate) { + final StringBuilder error = new StringBuilder(); + if (remoteClusterLicenseInfo.licenseInfo().getStatus() != LicenseStatus.ACTIVE) { + error.append(String.format(Locale.ROOT, "the license on cluster [%s] is not active", remoteClusterLicenseInfo.clusterAlias())); + } else { + assert predicate.test(remoteClusterLicenseInfo.licenseInfo()) == false : "license must be incompatible to build error message"; + final String message = String.format( + Locale.ROOT, + "the license mode [%s] on cluster [%s] does not enable [%s]", + License.OperationMode.resolve(remoteClusterLicenseInfo.licenseInfo().getMode()), + remoteClusterLicenseInfo.clusterAlias(), + feature); + error.append(message); + } + + return error.toString(); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java new file mode 100644 index 00000000000..a8627d21542 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/license/RemoteClusterLicenseCheckerTests.java @@ -0,0 +1,414 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.license; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.protocol.xpack.XPackInfoResponse; +import org.elasticsearch.protocol.xpack.license.LicenseStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.action.XPackInfoAction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasToString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.argThat; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public final class RemoteClusterLicenseCheckerTests extends ESTestCase { + + public void testIsNotRemoteIndex() { + assertFalse(RemoteClusterLicenseChecker.isRemoteIndex("local-index")); + } + + public void testIsRemoteIndex() { + assertTrue(RemoteClusterLicenseChecker.isRemoteIndex("remote-cluster:remote-index")); + } + + public void testNoRemoteIndex() { + final List indices = Arrays.asList("local-index1", "local-index2"); + assertFalse(RemoteClusterLicenseChecker.containsRemoteIndex(indices)); + } + + public void testRemoteIndex() { + final List indices = Arrays.asList("local-index", "remote-cluster:remote-index"); + assertTrue(RemoteClusterLicenseChecker.containsRemoteIndex(indices)); + } + + public void testNoRemoteIndices() { + final List indices = Collections.singletonList("local-index"); + assertThat(RemoteClusterLicenseChecker.remoteIndices(indices), is(empty())); + } + + public void testRemoteIndices() { + final List indices = Arrays.asList("local-index1", "remote-cluster1:index1", "local-index2", "remote-cluster2:index1"); + assertThat( + RemoteClusterLicenseChecker.remoteIndices(indices), + containsInAnyOrder("remote-cluster1:index1", "remote-cluster2:index1")); + } + + public void testNoRemoteClusterAliases() { + final List indices = Arrays.asList("local-index1", "local-index2"); + assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), empty()); + } + + public void testOneRemoteClusterAlias() { + final List indices = Arrays.asList("local-index1", "remote-cluster1:remote-index1"); + assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1")); + } + + public void testMoreThanOneRemoteClusterAlias() { + final List indices = Arrays.asList("remote-cluster1:remote-index1", "local-index1", "remote-cluster2:remote-index1"); + assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1", "remote-cluster2")); + } + + public void testDuplicateRemoteClusterAlias() { + final List indices = Arrays.asList( + "remote-cluster1:remote-index1", "local-index1", "remote-cluster2:index1", "remote-cluster2:remote-index2"); + assertThat(RemoteClusterLicenseChecker.remoteClusterAliases(indices), contains("remote-cluster1", "remote-cluster2")); + } + + public void testCheckRemoteClusterLicensesGivenCompatibleLicenses() { + final AtomicInteger index = new AtomicInteger(); + final List responses = new ArrayList<>(); + + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClient(threadPool); + doAnswer(invocationMock -> { + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + final AtomicReference licenseCheck = new AtomicReference<>(); + + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + licenseCheck.set(response); + } + + @Override + public void onFailure(final Exception e) { + fail(e.getMessage()); + } + + })); + + verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any()); + assertNotNull(licenseCheck.get()); + assertTrue(licenseCheck.get().isSuccess()); + } + + public void testCheckRemoteClusterLicensesGivenIncompatibleLicense() { + final AtomicInteger index = new AtomicInteger(); + final List remoteClusterAliases = Arrays.asList("good", "cluster-with-basic-license", "good2"); + final List responses = new ArrayList<>(); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClient(threadPool); + doAnswer(invocationMock -> { + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + final AtomicReference licenseCheck = new AtomicReference<>(); + + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + licenseCheck.set(response); + } + + @Override + public void onFailure(final Exception e) { + fail(e.getMessage()); + } + + })); + + verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any()); + assertNotNull(licenseCheck.get()); + assertFalse(licenseCheck.get().isSuccess()); + assertThat(licenseCheck.get().remoteClusterLicenseInfo().clusterAlias(), equalTo("cluster-with-basic-license")); + assertThat(licenseCheck.get().remoteClusterLicenseInfo().licenseInfo().getType(), equalTo("BASIC")); + } + + public void testCheckRemoteClusterLicencesGivenNonExistentCluster() { + final AtomicInteger index = new AtomicInteger(); + final List responses = new ArrayList<>(); + + final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); + final String failingClusterAlias = randomFrom(remoteClusterAliases); + final ThreadPool threadPool = createMockThreadPool(); + final Client client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, failingClusterAlias); + doAnswer(invocationMock -> { + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + final AtomicReference exception = new AtomicReference<>(); + + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + fail(); + } + + @Override + public void onFailure(final Exception e) { + exception.set(e); + } + + })); + + assertNotNull(exception.get()); + assertThat(exception.get(), instanceOf(ElasticsearchException.class)); + assertThat(exception.get().getMessage(), equalTo("could not determine the license type for cluster [" + failingClusterAlias + "]")); + assertNotNull(exception.get().getCause()); + assertThat(exception.get().getCause(), instanceOf(IllegalArgumentException.class)); + } + + public void testRemoteClusterLicenseCallUsesSystemContext() throws InterruptedException { + final ThreadPool threadPool = new TestThreadPool(getTestName()); + + try { + final Client client = createMockClient(threadPool); + doAnswer(invocationMock -> { + assertTrue(threadPool.getThreadContext().isSystemContext()); + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + + final List remoteClusterAliases = Collections.singletonList("valid"); + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, doubleInvocationProtectingListener(ActionListener.wrap(() -> {}))); + + verify(client, times(1)).execute(same(XPackInfoAction.INSTANCE), any(), any()); + } finally { + terminate(threadPool); + } + } + + public void testListenerIsExecutedWithCallingContext() throws InterruptedException { + final AtomicInteger index = new AtomicInteger(); + final List responses = new ArrayList<>(); + + final ThreadPool threadPool = new TestThreadPool(getTestName()); + + try { + final List remoteClusterAliases = Arrays.asList("valid1", "valid2", "valid3"); + final Client client; + final boolean failure = randomBoolean(); + if (failure) { + client = createMockClientThatThrowsOnGetRemoteClusterClient(threadPool, randomFrom(remoteClusterAliases)); + } else { + client = createMockClient(threadPool); + } + doAnswer(invocationMock -> { + @SuppressWarnings("unchecked") ActionListener listener = + (ActionListener) invocationMock.getArguments()[2]; + listener.onResponse(responses.get(index.getAndIncrement())); + return null; + }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); + + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); + + final RemoteClusterLicenseChecker licenseChecker = + new RemoteClusterLicenseChecker(client, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial); + + final AtomicBoolean listenerInvoked = new AtomicBoolean(); + threadPool.getThreadContext().putHeader("key", "value"); + licenseChecker.checkRemoteClusterLicenses( + remoteClusterAliases, + doubleInvocationProtectingListener(new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + if (failure) { + fail(); + } + assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value")); + assertFalse(threadPool.getThreadContext().isSystemContext()); + listenerInvoked.set(true); + } + + @Override + public void onFailure(final Exception e) { + if (failure == false) { + fail(); + } + assertThat(threadPool.getThreadContext().getHeader("key"), equalTo("value")); + assertFalse(threadPool.getThreadContext().isSystemContext()); + listenerInvoked.set(true); + } + + })); + + assertTrue(listenerInvoked.get()); + } finally { + terminate(threadPool); + } + } + + public void testBuildErrorMessageForActiveCompatibleLicense() { + final XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse(); + final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info = + new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence); + final AssertionError e = expectThrows( + AssertionError.class, + () -> RemoteClusterLicenseChecker.buildErrorMessage("", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial)); + assertThat(e, hasToString(containsString("license must be incompatible to build error message"))); + } + + public void testBuildErrorMessageForIncompatibleLicense() { + final XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse(); + final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info = + new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense); + assertThat( + RemoteClusterLicenseChecker.buildErrorMessage("Feature", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial), + equalTo("the license mode [BASIC] on cluster [basic-cluster] does not enable [Feature]")); + } + + public void testBuildErrorMessageForInactiveLicense() { + final XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse(); + final RemoteClusterLicenseChecker.RemoteClusterLicenseInfo info = + new RemoteClusterLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense); + assertThat( + RemoteClusterLicenseChecker.buildErrorMessage("Feature", info, RemoteClusterLicenseChecker::isLicensePlatinumOrTrial), + equalTo("the license on cluster [expired-cluster] is not active")); + } + + private ActionListener doubleInvocationProtectingListener( + final ActionListener listener) { + final AtomicBoolean listenerInvoked = new AtomicBoolean(); + return new ActionListener() { + + @Override + public void onResponse(final RemoteClusterLicenseChecker.LicenseCheck response) { + if (listenerInvoked.compareAndSet(false, true) == false) { + fail("listener invoked twice"); + } + listener.onResponse(response); + } + + @Override + public void onFailure(final Exception e) { + if (listenerInvoked.compareAndSet(false, true) == false) { + fail("listener invoked twice"); + } + listener.onFailure(e); + } + + }; + } + + private ThreadPool createMockThreadPool() { + final ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + return threadPool; + } + + private Client createMockClient(final ThreadPool threadPool) { + return createMockClient(threadPool, client -> when(client.getRemoteClusterClient(anyString())).thenReturn(client)); + } + + private Client createMockClientThatThrowsOnGetRemoteClusterClient(final ThreadPool threadPool, final String clusterAlias) { + return createMockClient( + threadPool, + client -> { + when(client.getRemoteClusterClient(clusterAlias)).thenThrow(new IllegalArgumentException()); + when(client.getRemoteClusterClient(argThat(not(clusterAlias)))).thenReturn(client); + }); + } + + private Client createMockClient(final ThreadPool threadPool, final Consumer finish) { + final Client client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + finish.accept(client); + return client; + } + + private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong()); + } + + private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", LicenseStatus.ACTIVE, randomNonNegativeLong()); + } + + private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() { + return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.EXPIRED, randomNonNegativeLong()); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java index 0ea9eb77648..d6ebdd0449e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.RemoteClusterLicenseChecker; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.persistent.PersistentTaskState; @@ -46,10 +47,10 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; import org.elasticsearch.xpack.ml.datafeed.DatafeedNodeSelector; -import org.elasticsearch.xpack.ml.datafeed.MlRemoteLicenseChecker; import org.elasticsearch.xpack.ml.datafeed.extractor.DataExtractorFactory; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Predicate; @@ -141,19 +142,22 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction { - if (response.isViolated()) { + if (response.isSuccess() == false) { listener.onFailure(createUnlicensedError(datafeed.getId(), response)); } else { createDataExtractor(job, datafeed, params, waitForTaskListener); } }, - e -> listener.onFailure(createUnknownLicenseError(datafeed.getId(), - MlRemoteLicenseChecker.remoteIndices(datafeed.getIndices()), e)) + e -> listener.onFailure( + createUnknownLicenseError( + datafeed.getId(), RemoteClusterLicenseChecker.remoteIndices(datafeed.getIndices()), e)) )); } else { createDataExtractor(job, datafeed, params, waitForTaskListener); @@ -232,23 +236,35 @@ public class TransportStartDatafeedAction extends TransportMasterNodeAction remoteIndices, - Exception cause) { - String message = "Cannot start datafeed [" + datafeedId + "] as it is configured to use" - + " indices on a remote cluster " + remoteIndices - + " but the license type could not be verified"; + private ElasticsearchStatusException createUnknownLicenseError( + final String datafeedId, final List remoteIndices, final Exception cause) { + final int numberOfRemoteClusters = RemoteClusterLicenseChecker.remoteClusterAliases(remoteIndices).size(); + assert numberOfRemoteClusters > 0; + final String remoteClusterQualifier = numberOfRemoteClusters == 1 ? "a remote cluster" : "remote clusters"; + final String licenseTypeQualifier = numberOfRemoteClusters == 1 ? "" : "s"; + final String message = String.format( + Locale.ROOT, + "cannot start datafeed [%s] as it uses indices on %s %s but the license type%s could not be verified", + datafeedId, + remoteClusterQualifier, + remoteIndices, + licenseTypeQualifier); - return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, new Exception(cause.getMessage())); + return new ElasticsearchStatusException(message, RestStatus.BAD_REQUEST, cause); } public static class StartDatafeedPersistentTasksExecutor extends PersistentTasksExecutor { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java index a6be0476486..ce3f611b222 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/DatafeedNodeSelector.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.logging.Loggers; +import org.elasticsearch.license.RemoteClusterLicenseChecker; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; @@ -92,7 +93,7 @@ public class DatafeedNodeSelector { List indices = datafeed.getIndices(); for (String index : indices) { - if (MlRemoteLicenseChecker.isRemoteIndex(index)) { + if (RemoteClusterLicenseChecker.isRemoteIndex(index)) { // We cannot verify remote indices continue; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java deleted file mode 100644 index b0eeed2c800..00000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseChecker.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ - -package org.elasticsearch.xpack.ml.datafeed; - -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.license.License; -import org.elasticsearch.protocol.xpack.XPackInfoRequest; -import org.elasticsearch.protocol.xpack.XPackInfoResponse; -import org.elasticsearch.protocol.xpack.license.LicenseStatus; -import org.elasticsearch.transport.ActionNotFoundTransportException; -import org.elasticsearch.transport.RemoteClusterAware; -import org.elasticsearch.xpack.core.action.XPackInfoAction; - -import java.util.EnumSet; -import java.util.Iterator; -import java.util.List; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -/** - * ML datafeeds can use cross cluster search to access data in a remote cluster. - * The remote cluster should be licenced for ML this class performs that check - * using the _xpack (info) endpoint. - */ -public class MlRemoteLicenseChecker { - - private final Client client; - - public static class RemoteClusterLicenseInfo { - private final String clusterName; - private final XPackInfoResponse.LicenseInfo licenseInfo; - - RemoteClusterLicenseInfo(String clusterName, XPackInfoResponse.LicenseInfo licenseInfo) { - this.clusterName = clusterName; - this.licenseInfo = licenseInfo; - } - - public String getClusterName() { - return clusterName; - } - - public XPackInfoResponse.LicenseInfo getLicenseInfo() { - return licenseInfo; - } - } - - public class LicenseViolation { - private final RemoteClusterLicenseInfo licenseInfo; - - private LicenseViolation(@Nullable RemoteClusterLicenseInfo licenseInfo) { - this.licenseInfo = licenseInfo; - } - - public boolean isViolated() { - return licenseInfo != null; - } - - public RemoteClusterLicenseInfo get() { - return licenseInfo; - } - } - - public MlRemoteLicenseChecker(Client client) { - this.client = client; - } - - /** - * Check each cluster is licensed for ML. - * This function evaluates lazily and will terminate when the first cluster - * that is not licensed is found or an error occurs. - * - * @param clusterNames List of remote cluster names - * @param listener Response listener - */ - public void checkRemoteClusterLicenses(List clusterNames, ActionListener listener) { - final Iterator itr = clusterNames.iterator(); - if (itr.hasNext() == false) { - listener.onResponse(new LicenseViolation(null)); - return; - } - - final AtomicReference clusterName = new AtomicReference<>(itr.next()); - - ActionListener infoListener = new ActionListener() { - @Override - public void onResponse(XPackInfoResponse xPackInfoResponse) { - if (licenseSupportsML(xPackInfoResponse.getLicenseInfo()) == false) { - listener.onResponse(new LicenseViolation( - new RemoteClusterLicenseInfo(clusterName.get(), xPackInfoResponse.getLicenseInfo()))); - return; - } - - if (itr.hasNext()) { - clusterName.set(itr.next()); - remoteClusterLicense(clusterName.get(), this); - } else { - listener.onResponse(new LicenseViolation(null)); - } - } - - @Override - public void onFailure(Exception e) { - String message = "Could not determine the X-Pack licence type for cluster [" + clusterName.get() + "]"; - if (e instanceof ActionNotFoundTransportException) { - // This is likely to be because x-pack is not installed in the target cluster - message += ". Is X-Pack installed on the target cluster?"; - } - listener.onFailure(new ElasticsearchException(message, e)); - } - }; - - remoteClusterLicense(clusterName.get(), infoListener); - } - - private void remoteClusterLicense(String clusterName, ActionListener listener) { - Client remoteClusterClient = client.getRemoteClusterClient(clusterName); - ThreadContext threadContext = remoteClusterClient.threadPool().getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - // we stash any context here since this is an internal execution and should not leak any - // existing context information. - threadContext.markAsSystemContext(); - - XPackInfoRequest request = new XPackInfoRequest(); - request.setCategories(EnumSet.of(XPackInfoRequest.Category.LICENSE)); - remoteClusterClient.execute(XPackInfoAction.INSTANCE, request, listener); - } - } - - static boolean licenseSupportsML(XPackInfoResponse.LicenseInfo licenseInfo) { - License.OperationMode mode = License.OperationMode.resolve(licenseInfo.getMode()); - return licenseInfo.getStatus() == LicenseStatus.ACTIVE && - (mode == License.OperationMode.PLATINUM || mode == License.OperationMode.TRIAL); - } - - public static boolean isRemoteIndex(String index) { - return index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR) != -1; - } - - public static boolean containsRemoteIndex(List indices) { - return indices.stream().anyMatch(MlRemoteLicenseChecker::isRemoteIndex); - } - - /** - * Get any remote indices used in cross cluster search. - * Remote indices are of the form {@code cluster_name:index_name} - * @return List of remote cluster indices - */ - public static List remoteIndices(List indices) { - return indices.stream().filter(MlRemoteLicenseChecker::isRemoteIndex).collect(Collectors.toList()); - } - - /** - * Extract the list of remote cluster names from the list of indices. - * @param indices List of indices. Remote cluster indices are prefixed - * with {@code cluster-name:} - * @return Every cluster name found in {@code indices} - */ - public static List remoteClusterNames(List indices) { - return indices.stream() - .filter(MlRemoteLicenseChecker::isRemoteIndex) - .map(index -> index.substring(0, index.indexOf(RemoteClusterAware.REMOTE_CLUSTER_INDEX_SEPARATOR))) - .distinct() - .collect(Collectors.toList()); - } - - public static String buildErrorMessage(RemoteClusterLicenseInfo clusterLicenseInfo) { - StringBuilder error = new StringBuilder(); - if (clusterLicenseInfo.licenseInfo.getStatus() != LicenseStatus.ACTIVE) { - error.append("The license on cluster [").append(clusterLicenseInfo.clusterName) - .append("] is not active. "); - } else { - License.OperationMode mode = License.OperationMode.resolve(clusterLicenseInfo.licenseInfo.getMode()); - if (mode != License.OperationMode.PLATINUM && mode != License.OperationMode.TRIAL) { - error.append("The license mode [").append(mode) - .append("] on cluster [") - .append(clusterLicenseInfo.clusterName) - .append("] does not enable Machine Learning. "); - } - } - - error.append(Strings.toString(clusterLicenseInfo.licenseInfo)); - return error.toString(); - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java index 72c8d361dd8..610a5c1b92f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDatafeedActionTests.java @@ -3,10 +3,12 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ + package org.elasticsearch.xpack.ml.action; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.MlMetadata; @@ -14,7 +16,6 @@ import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; -import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; import org.elasticsearch.xpack.ml.datafeed.DatafeedManagerTests; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java deleted file mode 100644 index 81e4c75cfad..00000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/MlRemoteLicenseCheckerTests.java +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ - -package org.elasticsearch.xpack.ml.datafeed; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.protocol.xpack.XPackInfoResponse; -import org.elasticsearch.protocol.xpack.license.LicenseStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.core.action.XPackInfoAction; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; -import static org.mockito.Matchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class MlRemoteLicenseCheckerTests extends ESTestCase { - - public void testIsRemoteIndex() { - List indices = Arrays.asList("local-index1", "local-index2"); - assertFalse(MlRemoteLicenseChecker.containsRemoteIndex(indices)); - indices = Arrays.asList("local-index1", "remote-cluster:remote-index2"); - assertTrue(MlRemoteLicenseChecker.containsRemoteIndex(indices)); - } - - public void testRemoteIndices() { - List indices = Collections.singletonList("local-index"); - assertThat(MlRemoteLicenseChecker.remoteIndices(indices), is(empty())); - indices = Arrays.asList("local-index", "remote-cluster:index1", "local-index2", "remote-cluster2:index1"); - assertThat(MlRemoteLicenseChecker.remoteIndices(indices), containsInAnyOrder("remote-cluster:index1", "remote-cluster2:index1")); - } - - public void testRemoteClusterNames() { - List indices = Arrays.asList("local-index1", "local-index2"); - assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), empty()); - indices = Arrays.asList("local-index1", "remote-cluster1:remote-index2"); - assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1")); - indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1"); - assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2")); - indices = Arrays.asList("remote-cluster1:index2", "index1", "remote-cluster2:index1", "remote-cluster2:index2"); - assertThat(MlRemoteLicenseChecker.remoteClusterNames(indices), contains("remote-cluster1", "remote-cluster2")); - } - - public void testLicenseSupportsML() { - XPackInfoResponse.LicenseInfo licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", - LicenseStatus.ACTIVE, randomNonNegativeLong()); - assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); - - licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "trial", "trial", LicenseStatus.EXPIRED, randomNonNegativeLong()); - assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); - - licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "GOLD", "GOLD", LicenseStatus.ACTIVE, randomNonNegativeLong()); - assertFalse(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); - - licenseInfo = new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong()); - assertTrue(MlRemoteLicenseChecker.licenseSupportsML(licenseInfo)); - } - - public void testCheckRemoteClusterLicenses_givenValidLicenses() { - final AtomicInteger index = new AtomicInteger(0); - final List responses = new ArrayList<>(); - - Client client = createMockClient(); - doAnswer(invocationMock -> { - @SuppressWarnings("raw_types") - ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; - listener.onResponse(responses.get(index.getAndIncrement())); - return null; - }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); - - - List remoteClusterNames = Arrays.asList("valid1", "valid2", "valid3"); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - - MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client); - AtomicReference licCheckResponse = new AtomicReference<>(); - - licenseChecker.checkRemoteClusterLicenses(remoteClusterNames, - new ActionListener() { - @Override - public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) { - licCheckResponse.set(response); - } - - @Override - public void onFailure(Exception e) { - fail(e.getMessage()); - } - }); - - verify(client, times(3)).execute(same(XPackInfoAction.INSTANCE), any(), any()); - assertNotNull(licCheckResponse.get()); - assertFalse(licCheckResponse.get().isViolated()); - assertNull(licCheckResponse.get().get()); - } - - public void testCheckRemoteClusterLicenses_givenInvalidLicense() { - final AtomicInteger index = new AtomicInteger(0); - List remoteClusterNames = Arrays.asList("good", "cluster-with-basic-license", "good2"); - final List responses = new ArrayList<>(); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - responses.add(new XPackInfoResponse(null, createBasicLicenseResponse(), null)); - responses.add(new XPackInfoResponse(null, createPlatinumLicenseResponse(), null)); - - Client client = createMockClient(); - doAnswer(invocationMock -> { - @SuppressWarnings("raw_types") - ActionListener listener = (ActionListener) invocationMock.getArguments()[2]; - listener.onResponse(responses.get(index.getAndIncrement())); - return null; - }).when(client).execute(same(XPackInfoAction.INSTANCE), any(), any()); - - MlRemoteLicenseChecker licenseChecker = new MlRemoteLicenseChecker(client); - AtomicReference licCheckResponse = new AtomicReference<>(); - - licenseChecker.checkRemoteClusterLicenses(remoteClusterNames, - new ActionListener() { - @Override - public void onResponse(MlRemoteLicenseChecker.LicenseViolation response) { - licCheckResponse.set(response); - } - - @Override - public void onFailure(Exception e) { - fail(e.getMessage()); - } - }); - - verify(client, times(2)).execute(same(XPackInfoAction.INSTANCE), any(), any()); - assertNotNull(licCheckResponse.get()); - assertTrue(licCheckResponse.get().isViolated()); - assertEquals("cluster-with-basic-license", licCheckResponse.get().get().getClusterName()); - assertEquals("BASIC", licCheckResponse.get().get().getLicenseInfo().getType()); - } - - public void testBuildErrorMessage() { - XPackInfoResponse.LicenseInfo platinumLicence = createPlatinumLicenseResponse(); - MlRemoteLicenseChecker.RemoteClusterLicenseInfo info = - new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("platinum-cluster", platinumLicence); - assertEquals(Strings.toString(platinumLicence), MlRemoteLicenseChecker.buildErrorMessage(info)); - - XPackInfoResponse.LicenseInfo basicLicense = createBasicLicenseResponse(); - info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("basic-cluster", basicLicense); - String expected = "The license mode [BASIC] on cluster [basic-cluster] does not enable Machine Learning. " - + Strings.toString(basicLicense); - assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info)); - - XPackInfoResponse.LicenseInfo expiredLicense = createExpiredLicenseResponse(); - info = new MlRemoteLicenseChecker.RemoteClusterLicenseInfo("expired-cluster", expiredLicense); - expected = "The license on cluster [expired-cluster] is not active. " + Strings.toString(expiredLicense); - assertEquals(expected, MlRemoteLicenseChecker.buildErrorMessage(info)); - } - - private Client createMockClient() { - Client client = mock(Client.class); - ThreadPool threadPool = mock(ThreadPool.class); - when(client.threadPool()).thenReturn(threadPool); - when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); - when(client.getRemoteClusterClient(anyString())).thenReturn(client); - return client; - } - - private XPackInfoResponse.LicenseInfo createPlatinumLicenseResponse() { - return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.ACTIVE, randomNonNegativeLong()); - } - - private XPackInfoResponse.LicenseInfo createBasicLicenseResponse() { - return new XPackInfoResponse.LicenseInfo("uid", "BASIC", "BASIC", LicenseStatus.ACTIVE, randomNonNegativeLong()); - } - - private XPackInfoResponse.LicenseInfo createExpiredLicenseResponse() { - return new XPackInfoResponse.LicenseInfo("uid", "PLATINUM", "PLATINUM", LicenseStatus.EXPIRED, randomNonNegativeLong()); - } -}