From 19fc53296166baf40aba356519bd7501639b8e03 Mon Sep 17 00:00:00 2001 From: Colin Goodheart-Smithe Date: Thu, 16 Feb 2017 09:40:53 +0000 Subject: [PATCH] Restricts certain ml endpoints if license forbids ml (elastic/x-pack-elasticsearch#568) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This change adds checks to the following machine learning actions to check that machine learning is permitted by the license before executing the request and throw and exception if the license forbids it: * Create job * Create data feed * Open job * Start data feed This is not a final list of the restrictions we want to have in place if the license forbids ML access but it serves as a starting point and can be easily updated following further discussions. The change also moves the `transportClientMode` check to `createComponents` so we don’t try to start the server parts of machine learning (job manager, connection to named pipes etc.) if we are running in a transport client. Original commit: elastic/x-pack-elasticsearch@6c19ebd3bcace999b1289800255658b2e3ac61d0 --- .../org/elasticsearch/xpack/XPackPlugin.java | 2 +- .../xpack/ml/MachineLearning.java | 10 +- .../xpack/ml/action/OpenJobAction.java | 31 +- .../xpack/ml/action/PutDatafeedAction.java | 20 +- .../xpack/ml/action/PutJobAction.java | 17 +- .../xpack/ml/action/StartDatafeedAction.java | 17 +- .../ml/client/MachineLearningClient.java | 158 ++++++++ .../MachineLearningLicensingTests.java | 378 ++++++++++++++++++ .../xpack/ml/support/BaseMlIntegTestCase.java | 7 + 9 files changed, 613 insertions(+), 27 deletions(-) create mode 100644 plugin/src/main/java/org/elasticsearch/xpack/ml/client/MachineLearningClient.java create mode 100644 plugin/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java diff --git a/plugin/src/main/java/org/elasticsearch/xpack/XPackPlugin.java b/plugin/src/main/java/org/elasticsearch/xpack/XPackPlugin.java index 0d64eac56db..158830bb6fe 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/XPackPlugin.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/XPackPlugin.java @@ -230,7 +230,7 @@ public class XPackPlugin extends Plugin implements ScriptPlugin, ActionPlugin, I modules.addAll(monitoring.nodeModules()); modules.addAll(watcher.nodeModules()); modules.addAll(graph.createGuiceModules()); - modules.addAll(machineLearning.createGuiceModules()); + modules.addAll(machineLearning.nodeModules()); if (transportClientMode) { modules.add(b -> b.bind(XPackLicenseState.class).toProvider(Providers.of(null))); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 3be27816ff1..d091d04ebcc 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -18,7 +18,6 @@ import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.inject.Module; -import org.elasticsearch.common.inject.util.Providers; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; @@ -132,8 +131,6 @@ import org.elasticsearch.xpack.persistent.PersistentTasksInProgress; import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction; import org.elasticsearch.xpack.persistent.StartPersistentTaskAction; import org.elasticsearch.xpack.persistent.UpdatePersistentTaskStatusAction; -import org.elasticsearch.xpack.watcher.WatcherFeatureSet; -import org.elasticsearch.xpack.watcher.WatcherService; import java.io.IOException; import java.util.ArrayList; @@ -240,7 +237,7 @@ public class MachineLearning extends Plugin implements ActionPlugin { public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry) { - if (false == enabled) { + if (false == enabled || this.transportClientMode) { return emptyList(); } // Whether we are using native process is a good way to detect whether we are in dev / test mode: @@ -298,10 +295,7 @@ public class MachineLearning extends Plugin implements ActionPlugin { public Collection nodeModules() { List modules = new ArrayList<>(); modules.add(b -> { - XPackPlugin.bindFeatureSet(b, WatcherFeatureSet.class); - if (transportClientMode || enabled == false) { - b.bind(WatcherService.class).toProvider(Providers.of(null)); - } + XPackPlugin.bindFeatureSet(b, MachineLearningFeatureSet.class); }); return modules; diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java index ef74d2753a4..fae7c10bafb 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/OpenJobAction.java @@ -30,12 +30,15 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.XPackPlugin; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.job.config.Job; import org.elasticsearch.xpack.ml.job.config.JobState; @@ -242,14 +245,16 @@ public class OpenJobAction extends Action listener) { - // If we already know that we can't find an ml node because all ml nodes are running at capacity or - // simply because there are no ml nodes in the cluster then we fail quickly here: - ClusterState clusterState = clusterService.state(); - if (selectLeastLoadedMlNode(request.getJobId(), clusterState, logger) == null) { - throw new ElasticsearchStatusException("no nodes available to open job [" + request.getJobId() + "]", - RestStatus.TOO_MANY_REQUESTS); - } + if (licenseState.isMachineLearningAllowed()) { + // If we already know that we can't find an ml node because all ml nodes are running at capacity or + // simply because there are no ml nodes in the cluster then we fail quickly here: + ClusterState clusterState = clusterService.state(); + if (selectLeastLoadedMlNode(request.getJobId(), clusterState, logger) == null) { + throw new ElasticsearchStatusException("no nodes available to open job [" + request.getJobId() + "]", + RestStatus.TOO_MANY_REQUESTS); + } - ActionListener finalListener = - ActionListener.wrap(response -> waitForJobStarted(request, response, listener), listener::onFailure); - super.doExecute(request, finalListener); + ActionListener finalListener = + ActionListener.wrap(response -> waitForJobStarted(request, response, listener), listener::onFailure); + super.doExecute(request, finalListener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING)); + } } void waitForJobStarted(Request request, PersistentActionResponse response, ActionListener listener) { diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutDatafeedAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutDatafeedAction.java index 413356e1bd8..bbde9370209 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutDatafeedAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutDatafeedAction.java @@ -29,10 +29,14 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.ml.job.metadata.MlMetadata; +import org.elasticsearch.xpack.XPackPlugin; import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig; +import org.elasticsearch.xpack.ml.job.metadata.MlMetadata; import java.io.IOException; import java.util.Objects; @@ -175,12 +179,15 @@ public class PutDatafeedAction extends Action { + private XPackLicenseState licenseState; + @Inject public TransportAction(Settings settings, TransportService transportService, ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, + ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver) { super(settings, PutDatafeedAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, Request::new); + this.licenseState = licenseState; } @Override @@ -226,5 +233,14 @@ public class PutDatafeedAction extends Action listener) { + if (licenseState.isMachineLearningAllowed()) { + super.doExecute(task, request, listener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING)); + } + } } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutJobAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutJobAction.java index 8e7c6e4b82b..959515343ef 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutJobAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/PutJobAction.java @@ -28,8 +28,12 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.XPackPlugin; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.config.Job; @@ -180,13 +184,15 @@ public class PutJobAction extends Action { private final JobManager jobManager; + private XPackLicenseState licenseState; @Inject public TransportAction(Settings settings, TransportService transportService, ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, + ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, JobManager jobManager) { super(settings, PutJobAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, Request::new); + this.licenseState = licenseState; this.jobManager = jobManager; } @@ -209,5 +215,14 @@ public class PutJobAction extends Action listener) { + if (licenseState.isMachineLearningAllowed()) { + super.doExecute(task, request, listener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING)); + } + } } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/StartDatafeedAction.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/StartDatafeedAction.java index b2ce65d6c82..cf4c4dfd5f7 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/ml/action/StartDatafeedAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/action/StartDatafeedAction.java @@ -28,12 +28,15 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportResponse; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.XPackPlugin; import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.ml.datafeed.DatafeedJobRunner; import org.elasticsearch.xpack.ml.datafeed.DatafeedJobValidator; @@ -244,23 +247,29 @@ public class StartDatafeedAction private final DatafeedStateObserver observer; private final DatafeedJobRunner datafeedJobRunner; + private XPackLicenseState licenseState; @Inject - public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool, + public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool, XPackLicenseState licenseState, PersistentActionService persistentActionService, PersistentActionRegistry persistentActionRegistry, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, ClusterService clusterService, DatafeedJobRunner datafeedJobRunner) { super(settings, NAME, false, threadPool, transportService, persistentActionService, persistentActionRegistry, actionFilters, indexNameExpressionResolver, Request::new, ThreadPool.Names.MANAGEMENT); + this.licenseState = licenseState; this.datafeedJobRunner = datafeedJobRunner; this.observer = new DatafeedStateObserver(threadPool, clusterService); } @Override protected void doExecute(Request request, ActionListener listener) { - ActionListener finalListener = - ActionListener.wrap(response -> waitForDatafeedStarted(request, response, listener), listener::onFailure); - super.doExecute(request, finalListener); + if (licenseState.isMachineLearningAllowed()) { + ActionListener finalListener = ActionListener + .wrap(response -> waitForDatafeedStarted(request, response, listener), listener::onFailure); + super.doExecute(request, finalListener); + } else { + listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING)); + } } void waitForDatafeedStarted(Request request, diff --git a/plugin/src/main/java/org/elasticsearch/xpack/ml/client/MachineLearningClient.java b/plugin/src/main/java/org/elasticsearch/xpack/ml/client/MachineLearningClient.java new file mode 100644 index 00000000000..7301a33e0f3 --- /dev/null +++ b/plugin/src/main/java/org/elasticsearch/xpack/ml/client/MachineLearningClient.java @@ -0,0 +1,158 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.client; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.xpack.ml.action.CloseJobAction; +import org.elasticsearch.xpack.ml.action.DeleteDatafeedAction; +import org.elasticsearch.xpack.ml.action.DeleteFilterAction; +import org.elasticsearch.xpack.ml.action.DeleteJobAction; +import org.elasticsearch.xpack.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.FlushJobAction; +import org.elasticsearch.xpack.ml.action.GetBucketsAction; +import org.elasticsearch.xpack.ml.action.GetCategoriesAction; +import org.elasticsearch.xpack.ml.action.GetDatafeedsAction; +import org.elasticsearch.xpack.ml.action.GetDatafeedsStatsAction; +import org.elasticsearch.xpack.ml.action.GetFiltersAction; +import org.elasticsearch.xpack.ml.action.GetInfluencersAction; +import org.elasticsearch.xpack.ml.action.GetJobsAction; +import org.elasticsearch.xpack.ml.action.GetJobsStatsAction; +import org.elasticsearch.xpack.ml.action.GetModelSnapshotsAction; +import org.elasticsearch.xpack.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.ml.action.OpenJobAction; +import org.elasticsearch.xpack.ml.action.PostDataAction; +import org.elasticsearch.xpack.ml.action.PutDatafeedAction; +import org.elasticsearch.xpack.ml.action.PutFilterAction; +import org.elasticsearch.xpack.ml.action.PutJobAction; +import org.elasticsearch.xpack.ml.action.RevertModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.StartDatafeedAction; +import org.elasticsearch.xpack.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.ml.action.UpdateDatafeedAction; +import org.elasticsearch.xpack.ml.action.UpdateJobAction; +import org.elasticsearch.xpack.ml.action.UpdateModelSnapshotAction; +import org.elasticsearch.xpack.persistent.PersistentActionResponse; +import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction; + +public class MachineLearningClient { + + private final ElasticsearchClient client; + + public MachineLearningClient(ElasticsearchClient client) { + this.client = client; + } + + public void closeJob(CloseJobAction.Request request, ActionListener listener) { + client.execute(CloseJobAction.INSTANCE, request, listener); + } + + public void deleteDatafeed(DeleteDatafeedAction.Request request, ActionListener listener) { + client.execute(DeleteDatafeedAction.INSTANCE, request, listener); + } + + public void deleteFilter(DeleteFilterAction.Request request, ActionListener listener) { + client.execute(DeleteFilterAction.INSTANCE, request, listener); + } + + public void deleteJob(DeleteJobAction.Request request, ActionListener listener) { + client.execute(DeleteJobAction.INSTANCE, request, listener); + } + + public void deleteModelSnapshot(DeleteModelSnapshotAction.Request request, + ActionListener listener) { + client.execute(DeleteModelSnapshotAction.INSTANCE, request, listener); + } + + public void flushJob(FlushJobAction.Request request, ActionListener listener) { + client.execute(FlushJobAction.INSTANCE, request, listener); + } + + public void getBuckets(GetBucketsAction.Request request, ActionListener listener) { + client.execute(GetBucketsAction.INSTANCE, request, listener); + } + + public void getCategories(GetCategoriesAction.Request request, ActionListener listener) { + client.execute(GetCategoriesAction.INSTANCE, request, listener); + } + + public void getDatafeeds(GetDatafeedsAction.Request request, ActionListener listener) { + client.execute(GetDatafeedsAction.INSTANCE, request, listener); + } + + public void getDatafeedsStats(GetDatafeedsStatsAction.Request request, ActionListener listener) { + client.execute(GetDatafeedsStatsAction.INSTANCE, request, listener); + } + + public void getFilters(GetFiltersAction.Request request, ActionListener listener) { + client.execute(GetFiltersAction.INSTANCE, request, listener); + } + + public void getInfluencers(GetInfluencersAction.Request request, ActionListener listener) { + client.execute(GetInfluencersAction.INSTANCE, request, listener); + } + + public void getJobs(GetJobsAction.Request request, ActionListener listener) { + client.execute(GetJobsAction.INSTANCE, request, listener); + } + + public void getJobsStats(GetJobsStatsAction.Request request, ActionListener listener) { + client.execute(GetJobsStatsAction.INSTANCE, request, listener); + } + + public void getModelSnapshots(GetModelSnapshotsAction.Request request, ActionListener listener) { + client.execute(GetModelSnapshotsAction.INSTANCE, request, listener); + } + + public void getRecords(GetRecordsAction.Request request, ActionListener listener) { + client.execute(GetRecordsAction.INSTANCE, request, listener); + } + + public void openJob(OpenJobAction.Request request, ActionListener listener) { + client.execute(OpenJobAction.INSTANCE, request, listener); + } + + public void postData(PostDataAction.Request request, ActionListener listener) { + client.execute(PostDataAction.INSTANCE, request, listener); + } + + public void putDatafeed(PutDatafeedAction.Request request, ActionListener listener) { + client.execute(PutDatafeedAction.INSTANCE, request, listener); + } + + public void putFilter(PutFilterAction.Request request, ActionListener listener) { + client.execute(PutFilterAction.INSTANCE, request, listener); + } + + public void putJob(PutJobAction.Request request, ActionListener listener) { + client.execute(PutJobAction.INSTANCE, request, listener); + } + + public void revertModelSnapshot(RevertModelSnapshotAction.Request request, + ActionListener listener) { + client.execute(RevertModelSnapshotAction.INSTANCE, request, listener); + } + + public void startDatafeed(StartDatafeedAction.Request request, ActionListener listener) { + client.execute(StartDatafeedAction.INSTANCE, request, listener); + } + + public void stopDatafeed(StopDatafeedAction.Request request, ActionListener listener) { + client.execute(StopDatafeedAction.INSTANCE, request, listener); + } + + public void updateDatafeed(UpdateDatafeedAction.Request request, ActionListener listener) { + client.execute(UpdateDatafeedAction.INSTANCE, request, listener); + } + + public void updateJob(UpdateJobAction.Request request, ActionListener listener) { + client.execute(UpdateJobAction.INSTANCE, request, listener); + } + + public void updateModelSnapshot(UpdateModelSnapshotAction.Request request, + ActionListener listener) { + client.execute(UpdateModelSnapshotAction.INSTANCE, request, listener); + } +} diff --git a/plugin/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java b/plugin/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java new file mode 100644 index 00000000000..c288fe198ee --- /dev/null +++ b/plugin/src/test/java/org/elasticsearch/license/MachineLearningLicensingTests.java @@ -0,0 +1,378 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.license; + +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.support.PlainListenableActionFuture; +import org.elasticsearch.client.transport.TransportClient; +import org.elasticsearch.license.License.OperationMode; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.xpack.TestXPackTransportClient; +import org.elasticsearch.xpack.XPackPlugin; +import org.elasticsearch.xpack.ml.action.CloseJobAction; +import org.elasticsearch.xpack.ml.action.DeleteDatafeedAction; +import org.elasticsearch.xpack.ml.action.DeleteJobAction; +import org.elasticsearch.xpack.ml.action.OpenJobAction; +import org.elasticsearch.xpack.ml.action.PutDatafeedAction; +import org.elasticsearch.xpack.ml.action.PutJobAction; +import org.elasticsearch.xpack.ml.action.StartDatafeedAction; +import org.elasticsearch.xpack.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.ml.client.MachineLearningClient; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.elasticsearch.xpack.persistent.PersistentActionResponse; +import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction; +import org.junit.Before; + +import java.util.Collections; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.is; + +public class MachineLearningLicensingTests extends BaseMlIntegTestCase { + + @Before + public void resetLicensing() { + enableLicensing(); + + ensureStableCluster(1); + ensureYellow(); + } + + public void testMachineLearningPutJobActionRestricted() throws Exception { + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + // test that license restricted apis do not work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), listener); + listener.actionGet(); + fail("put job action should not be enabled!"); + } catch (ElasticsearchSecurityException e) { + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING)); + } + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), listener); + PutJobAction.Response response = listener.actionGet(); + assertNotNull(response); + } + } + + public void testMachineLearningOpenJobActionRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response response = putJobListener.actionGet(); + assertNotNull(response); + } + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + // test that license restricted apis do not work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), listener); + listener.actionGet(); + fail("open job action should not be enabled!"); + } catch (ElasticsearchSecurityException e) { + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING)); + } + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), listener); + PersistentActionResponse response = listener.actionGet(); + assertNotNull(response); + } + } + + public void testMachineLearningPutDatafeedActionRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + } + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + // test that license restricted apis do not work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putDatafeed( + new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), listener); + listener.actionGet(); + fail("put datafeed action should not be enabled!"); + } catch (ElasticsearchSecurityException e) { + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING)); + } + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putDatafeed( + new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), listener); + PutDatafeedAction.Response response = listener.actionGet(); + assertNotNull(response); + } + } + + public void testMachineLearningStartDatafeedActionRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + PlainListenableActionFuture putDatafeedListener = new PlainListenableActionFuture<>( + client.threadPool()); + new MachineLearningClient(client).putDatafeed( + new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener); + PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet(); + assertNotNull(putDatafeedResponse); + PlainListenableActionFuture openJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener); + PersistentActionResponse openJobResponse = openJobListener.actionGet(); + assertNotNull(openJobResponse); + } + + // Pick a license that does not allow machine learning + License.OperationMode mode = randomInvalidLicenseType(); + enableLicensing(mode); + assertMLAllowed(false); + // test that license restricted apis do not work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), listener); + listener.actionGet(); + fail("start datafeed action should not be enabled!"); + } catch (ElasticsearchSecurityException e) { + assertThat(e.status(), is(RestStatus.FORBIDDEN)); + assertThat(e.getMessage(), containsString("non-compliant")); + assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING)); + } + + // Pick a license that does allow machine learning + mode = randomValidLicenseType(); + enableLicensing(mode); + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), listener); + PersistentActionResponse response = listener.actionGet(); + assertNotNull(response); + } + } + + public void testMachineLearningStopDatafeedActionNotRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + PlainListenableActionFuture putDatafeedListener = new PlainListenableActionFuture<>( + client.threadPool()); + new MachineLearningClient(client).putDatafeed( + new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener); + PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet(); + assertNotNull(putDatafeedResponse); + PlainListenableActionFuture openJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener); + PersistentActionResponse openJobResponse = openJobListener.actionGet(); + assertNotNull(openJobResponse); + PlainListenableActionFuture startDatafeedListener = new PlainListenableActionFuture<>( + client.threadPool()); + new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), startDatafeedListener); + PersistentActionResponse startDatafeedResponse = startDatafeedListener.actionGet(); + assertNotNull(startDatafeedResponse); + } + + // Pick a random license + License.OperationMode mode = randomLicenseType(); + enableLicensing(mode); + + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>( + client.threadPool()); + new MachineLearningClient(client).stopDatafeed(new StopDatafeedAction.Request("foobar"), listener); + listener.actionGet(); + } + } + + public void testMachineLearningCloseJobActionNotRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + PlainListenableActionFuture openJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener); + PersistentActionResponse openJobResponse = openJobListener.actionGet(); + assertNotNull(openJobResponse); + } + + // Pick a random license + License.OperationMode mode = randomLicenseType(); + enableLicensing(mode); + + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).closeJob(new CloseJobAction.Request("foo"), listener); + listener.actionGet(); + } + } + + public void testMachineLearningDeleteJobActionNotRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + } + + // Pick a random license + License.OperationMode mode = randomLicenseType(); + enableLicensing(mode); + + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).deleteJob(new DeleteJobAction.Request("foo"), listener); + listener.actionGet(); + } + } + + public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception { + + assertMLAllowed(true); + // test that license restricted apis do now work + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture putJobListener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener); + PutJobAction.Response putJobResponse = putJobListener.actionGet(); + assertNotNull(putJobResponse); + PlainListenableActionFuture putDatafeedListener = new PlainListenableActionFuture<>( + client.threadPool()); + new MachineLearningClient(client).putDatafeed( + new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener); + PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet(); + assertNotNull(putDatafeedResponse); + } + + // Pick a random license + License.OperationMode mode = randomLicenseType(); + enableLicensing(mode); + + try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) { + client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress()); + PlainListenableActionFuture listener = new PlainListenableActionFuture<>(client.threadPool()); + new MachineLearningClient(client).deleteDatafeed(new DeleteDatafeedAction.Request("foobar"), listener); + listener.actionGet(); + } + } + + private static OperationMode randomInvalidLicenseType() { + return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC); + } + + private static OperationMode randomValidLicenseType() { + return randomFrom(License.OperationMode.TRIAL, License.OperationMode.PLATINUM); + } + + private static OperationMode randomLicenseType() { + return randomFrom(License.OperationMode.values()); + } + + private static void assertMLAllowed(boolean expected) { + for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) { + assertEquals(licenseState.isMachineLearningAllowed(), expected); + } + } + + public static void disableLicensing() { + disableLicensing(randomValidLicenseType()); + } + + public static void disableLicensing(License.OperationMode operationMode) { + for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) { + licenseState.update(operationMode, false); + } + } + + public static void enableLicensing() { + enableLicensing(randomValidLicenseType()); + } + + public static void enableLicensing(License.OperationMode operationMode) { + for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) { + licenseState.update(operationMode, true); + } + } +} diff --git a/plugin/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/plugin/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index 0715962f2fc..dab6a75a1aa 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -60,6 +60,13 @@ public abstract class BaseMlIntegTestCase extends SecurityIntegTestCase { return settings.build(); } + @Override + protected Settings transportClientSettings() { + Settings.Builder settings = Settings.builder().put(super.transportClientSettings()); + settings.put(XPackSettings.MACHINE_LEARNING_ENABLED.getKey(), true); + return settings.build(); + } + protected Job.Builder createJob(String id) { DataDescription.Builder dataDescription = new DataDescription.Builder(); dataDescription.setFormat(DataDescription.DataFormat.JSON);