Restricts certain ml endpoints if license forbids ml (elastic/x-pack-elasticsearch#568)

This change adds checks to the following machine learning actions to check that machine learning is permitted by the license before executing the request and throw and exception if the license forbids it:

* Create job
* Create data feed
* Open job
* Start data feed

This is not a final list of the restrictions we want to have in place if the license forbids ML access but it serves as a starting point and can be easily updated following further discussions.

The change also moves the `transportClientMode` check to `createComponents` so we don’t try to start the server parts of machine learning (job manager, connection to named pipes etc.) if we are running in a transport client.

Original commit: elastic/x-pack-elasticsearch@6c19ebd3bc
This commit is contained in:
Colin Goodheart-Smithe 2017-02-16 09:40:53 +00:00 committed by GitHub
parent 54d57f6398
commit 19fc532961
9 changed files with 613 additions and 27 deletions

View File

@ -230,7 +230,7 @@ public class XPackPlugin extends Plugin implements ScriptPlugin, ActionPlugin, I
modules.addAll(monitoring.nodeModules());
modules.addAll(watcher.nodeModules());
modules.addAll(graph.createGuiceModules());
modules.addAll(machineLearning.createGuiceModules());
modules.addAll(machineLearning.nodeModules());
if (transportClientMode) {
modules.add(b -> b.bind(XPackLicenseState.class).toProvider(Providers.of(null)));

View File

@ -18,7 +18,6 @@ import org.elasticsearch.cluster.routing.UnassignedInfo;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.inject.Module;
import org.elasticsearch.common.inject.util.Providers;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.IndexScopedSettings;
@ -132,8 +131,6 @@ import org.elasticsearch.xpack.persistent.PersistentTasksInProgress;
import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction;
import org.elasticsearch.xpack.persistent.StartPersistentTaskAction;
import org.elasticsearch.xpack.persistent.UpdatePersistentTaskStatusAction;
import org.elasticsearch.xpack.watcher.WatcherFeatureSet;
import org.elasticsearch.xpack.watcher.WatcherService;
import java.io.IOException;
import java.util.ArrayList;
@ -240,7 +237,7 @@ public class MachineLearning extends Plugin implements ActionPlugin {
public Collection<Object> createComponents(Client client, ClusterService clusterService, ThreadPool threadPool,
ResourceWatcherService resourceWatcherService, ScriptService scriptService,
NamedXContentRegistry xContentRegistry) {
if (false == enabled) {
if (false == enabled || this.transportClientMode) {
return emptyList();
}
// Whether we are using native process is a good way to detect whether we are in dev / test mode:
@ -298,10 +295,7 @@ public class MachineLearning extends Plugin implements ActionPlugin {
public Collection<Module> nodeModules() {
List<Module> modules = new ArrayList<>();
modules.add(b -> {
XPackPlugin.bindFeatureSet(b, WatcherFeatureSet.class);
if (transportClientMode || enabled == false) {
b.bind(WatcherService.class).toProvider(Providers.of(null));
}
XPackPlugin.bindFeatureSet(b, MachineLearningFeatureSet.class);
});
return modules;

View File

@ -30,12 +30,15 @@ import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.config.Job;
import org.elasticsearch.xpack.ml.job.config.JobState;
@ -242,14 +245,16 @@ public class OpenJobAction extends Action<OpenJobAction.Request, PersistentActio
private final JobStateObserver observer;
private final ClusterService clusterService;
private final AutodetectProcessManager autodetectProcessManager;
private XPackLicenseState licenseState;
@Inject
public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool,
public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool, XPackLicenseState licenseState,
PersistentActionService persistentActionService, PersistentActionRegistry persistentActionRegistry,
ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver,
ClusterService clusterService, AutodetectProcessManager autodetectProcessManager) {
super(settings, OpenJobAction.NAME, false, threadPool, transportService, persistentActionService,
persistentActionRegistry, actionFilters, indexNameExpressionResolver, Request::new, ThreadPool.Names.MANAGEMENT);
this.licenseState = licenseState;
this.clusterService = clusterService;
this.autodetectProcessManager = autodetectProcessManager;
this.observer = new JobStateObserver(threadPool, clusterService);
@ -257,17 +262,21 @@ public class OpenJobAction extends Action<OpenJobAction.Request, PersistentActio
@Override
protected void doExecute(Request request, ActionListener<PersistentActionResponse> listener) {
// If we already know that we can't find an ml node because all ml nodes are running at capacity or
// simply because there are no ml nodes in the cluster then we fail quickly here:
ClusterState clusterState = clusterService.state();
if (selectLeastLoadedMlNode(request.getJobId(), clusterState, logger) == null) {
throw new ElasticsearchStatusException("no nodes available to open job [" + request.getJobId() + "]",
RestStatus.TOO_MANY_REQUESTS);
}
if (licenseState.isMachineLearningAllowed()) {
// If we already know that we can't find an ml node because all ml nodes are running at capacity or
// simply because there are no ml nodes in the cluster then we fail quickly here:
ClusterState clusterState = clusterService.state();
if (selectLeastLoadedMlNode(request.getJobId(), clusterState, logger) == null) {
throw new ElasticsearchStatusException("no nodes available to open job [" + request.getJobId() + "]",
RestStatus.TOO_MANY_REQUESTS);
}
ActionListener<PersistentActionResponse> finalListener =
ActionListener.wrap(response -> waitForJobStarted(request, response, listener), listener::onFailure);
super.doExecute(request, finalListener);
ActionListener<PersistentActionResponse> finalListener =
ActionListener.wrap(response -> waitForJobStarted(request, response, listener), listener::onFailure);
super.doExecute(request, finalListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING));
}
}
void waitForJobStarted(Request request, PersistentActionResponse response, ActionListener<PersistentActionResponse> listener) {

View File

@ -29,10 +29,14 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.ml.job.metadata.MlMetadata;
import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.ml.job.metadata.MlMetadata;
import java.io.IOException;
import java.util.Objects;
@ -175,12 +179,15 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
public static class TransportAction extends TransportMasterNodeAction<Request, Response> {
private XPackLicenseState licenseState;
@Inject
public TransportAction(Settings settings, TransportService transportService, ClusterService clusterService,
ThreadPool threadPool, ActionFilters actionFilters,
ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver) {
super(settings, PutDatafeedAction.NAME, transportService, clusterService, threadPool, actionFilters,
indexNameExpressionResolver, Request::new);
this.licenseState = licenseState;
}
@Override
@ -226,5 +233,14 @@ public class PutDatafeedAction extends Action<PutDatafeedAction.Request, PutData
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
}
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
if (licenseState.isMachineLearningAllowed()) {
super.doExecute(task, request, listener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING));
}
}
}
}

View File

@ -28,8 +28,12 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.ml.job.JobManager;
import org.elasticsearch.xpack.ml.job.config.Job;
@ -180,13 +184,15 @@ public class PutJobAction extends Action<PutJobAction.Request, PutJobAction.Resp
public static class TransportAction extends TransportMasterNodeAction<Request, Response> {
private final JobManager jobManager;
private XPackLicenseState licenseState;
@Inject
public TransportAction(Settings settings, TransportService transportService, ClusterService clusterService,
ThreadPool threadPool, ActionFilters actionFilters,
ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver, JobManager jobManager) {
super(settings, PutJobAction.NAME, transportService, clusterService, threadPool, actionFilters,
indexNameExpressionResolver, Request::new);
this.licenseState = licenseState;
this.jobManager = jobManager;
}
@ -209,5 +215,14 @@ public class PutJobAction extends Action<PutJobAction.Request, PutJobAction.Resp
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
}
@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
if (licenseState.isMachineLearningAllowed()) {
super.doExecute(task, request, listener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING));
}
}
}
}

View File

@ -28,12 +28,15 @@ import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.ml.datafeed.DatafeedJobRunner;
import org.elasticsearch.xpack.ml.datafeed.DatafeedJobValidator;
@ -244,23 +247,29 @@ public class StartDatafeedAction
private final DatafeedStateObserver observer;
private final DatafeedJobRunner datafeedJobRunner;
private XPackLicenseState licenseState;
@Inject
public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool,
public TransportAction(Settings settings, TransportService transportService, ThreadPool threadPool, XPackLicenseState licenseState,
PersistentActionService persistentActionService, PersistentActionRegistry persistentActionRegistry,
ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver,
ClusterService clusterService, DatafeedJobRunner datafeedJobRunner) {
super(settings, NAME, false, threadPool, transportService, persistentActionService, persistentActionRegistry,
actionFilters, indexNameExpressionResolver, Request::new, ThreadPool.Names.MANAGEMENT);
this.licenseState = licenseState;
this.datafeedJobRunner = datafeedJobRunner;
this.observer = new DatafeedStateObserver(threadPool, clusterService);
}
@Override
protected void doExecute(Request request, ActionListener<PersistentActionResponse> listener) {
ActionListener<PersistentActionResponse> finalListener =
ActionListener.wrap(response -> waitForDatafeedStarted(request, response, listener), listener::onFailure);
super.doExecute(request, finalListener);
if (licenseState.isMachineLearningAllowed()) {
ActionListener<PersistentActionResponse> finalListener = ActionListener
.wrap(response -> waitForDatafeedStarted(request, response, listener), listener::onFailure);
super.doExecute(request, finalListener);
} else {
listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.MACHINE_LEARNING));
}
}
void waitForDatafeedStarted(Request request,

View File

@ -0,0 +1,158 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.client;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.xpack.ml.action.CloseJobAction;
import org.elasticsearch.xpack.ml.action.DeleteDatafeedAction;
import org.elasticsearch.xpack.ml.action.DeleteFilterAction;
import org.elasticsearch.xpack.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.ml.action.DeleteModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.FlushJobAction;
import org.elasticsearch.xpack.ml.action.GetBucketsAction;
import org.elasticsearch.xpack.ml.action.GetCategoriesAction;
import org.elasticsearch.xpack.ml.action.GetDatafeedsAction;
import org.elasticsearch.xpack.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.ml.action.GetFiltersAction;
import org.elasticsearch.xpack.ml.action.GetInfluencersAction;
import org.elasticsearch.xpack.ml.action.GetJobsAction;
import org.elasticsearch.xpack.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.ml.action.GetModelSnapshotsAction;
import org.elasticsearch.xpack.ml.action.GetRecordsAction;
import org.elasticsearch.xpack.ml.action.OpenJobAction;
import org.elasticsearch.xpack.ml.action.PostDataAction;
import org.elasticsearch.xpack.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.ml.action.PutFilterAction;
import org.elasticsearch.xpack.ml.action.PutJobAction;
import org.elasticsearch.xpack.ml.action.RevertModelSnapshotAction;
import org.elasticsearch.xpack.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.ml.action.UpdateDatafeedAction;
import org.elasticsearch.xpack.ml.action.UpdateJobAction;
import org.elasticsearch.xpack.ml.action.UpdateModelSnapshotAction;
import org.elasticsearch.xpack.persistent.PersistentActionResponse;
import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction;
public class MachineLearningClient {
private final ElasticsearchClient client;
public MachineLearningClient(ElasticsearchClient client) {
this.client = client;
}
public void closeJob(CloseJobAction.Request request, ActionListener<CloseJobAction.Response> listener) {
client.execute(CloseJobAction.INSTANCE, request, listener);
}
public void deleteDatafeed(DeleteDatafeedAction.Request request, ActionListener<DeleteDatafeedAction.Response> listener) {
client.execute(DeleteDatafeedAction.INSTANCE, request, listener);
}
public void deleteFilter(DeleteFilterAction.Request request, ActionListener<DeleteFilterAction.Response> listener) {
client.execute(DeleteFilterAction.INSTANCE, request, listener);
}
public void deleteJob(DeleteJobAction.Request request, ActionListener<DeleteJobAction.Response> listener) {
client.execute(DeleteJobAction.INSTANCE, request, listener);
}
public void deleteModelSnapshot(DeleteModelSnapshotAction.Request request,
ActionListener<DeleteModelSnapshotAction.Response> listener) {
client.execute(DeleteModelSnapshotAction.INSTANCE, request, listener);
}
public void flushJob(FlushJobAction.Request request, ActionListener<FlushJobAction.Response> listener) {
client.execute(FlushJobAction.INSTANCE, request, listener);
}
public void getBuckets(GetBucketsAction.Request request, ActionListener<GetBucketsAction.Response> listener) {
client.execute(GetBucketsAction.INSTANCE, request, listener);
}
public void getCategories(GetCategoriesAction.Request request, ActionListener<GetCategoriesAction.Response> listener) {
client.execute(GetCategoriesAction.INSTANCE, request, listener);
}
public void getDatafeeds(GetDatafeedsAction.Request request, ActionListener<GetDatafeedsAction.Response> listener) {
client.execute(GetDatafeedsAction.INSTANCE, request, listener);
}
public void getDatafeedsStats(GetDatafeedsStatsAction.Request request, ActionListener<GetDatafeedsStatsAction.Response> listener) {
client.execute(GetDatafeedsStatsAction.INSTANCE, request, listener);
}
public void getFilters(GetFiltersAction.Request request, ActionListener<GetFiltersAction.Response> listener) {
client.execute(GetFiltersAction.INSTANCE, request, listener);
}
public void getInfluencers(GetInfluencersAction.Request request, ActionListener<GetInfluencersAction.Response> listener) {
client.execute(GetInfluencersAction.INSTANCE, request, listener);
}
public void getJobs(GetJobsAction.Request request, ActionListener<GetJobsAction.Response> listener) {
client.execute(GetJobsAction.INSTANCE, request, listener);
}
public void getJobsStats(GetJobsStatsAction.Request request, ActionListener<GetJobsStatsAction.Response> listener) {
client.execute(GetJobsStatsAction.INSTANCE, request, listener);
}
public void getModelSnapshots(GetModelSnapshotsAction.Request request, ActionListener<GetModelSnapshotsAction.Response> listener) {
client.execute(GetModelSnapshotsAction.INSTANCE, request, listener);
}
public void getRecords(GetRecordsAction.Request request, ActionListener<GetRecordsAction.Response> listener) {
client.execute(GetRecordsAction.INSTANCE, request, listener);
}
public void openJob(OpenJobAction.Request request, ActionListener<PersistentActionResponse> listener) {
client.execute(OpenJobAction.INSTANCE, request, listener);
}
public void postData(PostDataAction.Request request, ActionListener<PostDataAction.Response> listener) {
client.execute(PostDataAction.INSTANCE, request, listener);
}
public void putDatafeed(PutDatafeedAction.Request request, ActionListener<PutDatafeedAction.Response> listener) {
client.execute(PutDatafeedAction.INSTANCE, request, listener);
}
public void putFilter(PutFilterAction.Request request, ActionListener<PutFilterAction.Response> listener) {
client.execute(PutFilterAction.INSTANCE, request, listener);
}
public void putJob(PutJobAction.Request request, ActionListener<PutJobAction.Response> listener) {
client.execute(PutJobAction.INSTANCE, request, listener);
}
public void revertModelSnapshot(RevertModelSnapshotAction.Request request,
ActionListener<RevertModelSnapshotAction.Response> listener) {
client.execute(RevertModelSnapshotAction.INSTANCE, request, listener);
}
public void startDatafeed(StartDatafeedAction.Request request, ActionListener<PersistentActionResponse> listener) {
client.execute(StartDatafeedAction.INSTANCE, request, listener);
}
public void stopDatafeed(StopDatafeedAction.Request request, ActionListener<RemovePersistentTaskAction.Response> listener) {
client.execute(StopDatafeedAction.INSTANCE, request, listener);
}
public void updateDatafeed(UpdateDatafeedAction.Request request, ActionListener<PutDatafeedAction.Response> listener) {
client.execute(UpdateDatafeedAction.INSTANCE, request, listener);
}
public void updateJob(UpdateJobAction.Request request, ActionListener<PutJobAction.Response> listener) {
client.execute(UpdateJobAction.INSTANCE, request, listener);
}
public void updateModelSnapshot(UpdateModelSnapshotAction.Request request,
ActionListener<UpdateModelSnapshotAction.Response> listener) {
client.execute(UpdateModelSnapshotAction.INSTANCE, request, listener);
}
}

View File

@ -0,0 +1,378 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.license;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.support.PlainListenableActionFuture;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.xpack.TestXPackTransportClient;
import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.ml.action.CloseJobAction;
import org.elasticsearch.xpack.ml.action.DeleteDatafeedAction;
import org.elasticsearch.xpack.ml.action.DeleteJobAction;
import org.elasticsearch.xpack.ml.action.OpenJobAction;
import org.elasticsearch.xpack.ml.action.PutDatafeedAction;
import org.elasticsearch.xpack.ml.action.PutJobAction;
import org.elasticsearch.xpack.ml.action.StartDatafeedAction;
import org.elasticsearch.xpack.ml.action.StopDatafeedAction;
import org.elasticsearch.xpack.ml.client.MachineLearningClient;
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
import org.elasticsearch.xpack.persistent.PersistentActionResponse;
import org.elasticsearch.xpack.persistent.RemovePersistentTaskAction;
import org.junit.Before;
import java.util.Collections;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
public class MachineLearningLicensingTests extends BaseMlIntegTestCase {
@Before
public void resetLicensing() {
enableLicensing();
ensureStableCluster(1);
ensureYellow();
}
public void testMachineLearningPutJobActionRestricted() throws Exception {
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// test that license restricted apis do not work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), listener);
listener.actionGet();
fail("put job action should not be enabled!");
} catch (ElasticsearchSecurityException e) {
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING));
}
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), listener);
PutJobAction.Response response = listener.actionGet();
assertNotNull(response);
}
}
public void testMachineLearningOpenJobActionRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response response = putJobListener.actionGet();
assertNotNull(response);
}
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// test that license restricted apis do not work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PersistentActionResponse> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), listener);
listener.actionGet();
fail("open job action should not be enabled!");
} catch (ElasticsearchSecurityException e) {
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING));
}
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PersistentActionResponse> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), listener);
PersistentActionResponse response = listener.actionGet();
assertNotNull(response);
}
}
public void testMachineLearningPutDatafeedActionRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
}
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// test that license restricted apis do not work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutDatafeedAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putDatafeed(
new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), listener);
listener.actionGet();
fail("put datafeed action should not be enabled!");
} catch (ElasticsearchSecurityException e) {
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING));
}
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutDatafeedAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putDatafeed(
new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), listener);
PutDatafeedAction.Response response = listener.actionGet();
assertNotNull(response);
}
}
public void testMachineLearningStartDatafeedActionRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
PlainListenableActionFuture<PutDatafeedAction.Response> putDatafeedListener = new PlainListenableActionFuture<>(
client.threadPool());
new MachineLearningClient(client).putDatafeed(
new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener);
PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet();
assertNotNull(putDatafeedResponse);
PlainListenableActionFuture<PersistentActionResponse> openJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener);
PersistentActionResponse openJobResponse = openJobListener.actionGet();
assertNotNull(openJobResponse);
}
// Pick a license that does not allow machine learning
License.OperationMode mode = randomInvalidLicenseType();
enableLicensing(mode);
assertMLAllowed(false);
// test that license restricted apis do not work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PersistentActionResponse> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), listener);
listener.actionGet();
fail("start datafeed action should not be enabled!");
} catch (ElasticsearchSecurityException e) {
assertThat(e.status(), is(RestStatus.FORBIDDEN));
assertThat(e.getMessage(), containsString("non-compliant"));
assertThat(e.getMetadata(LicenseUtils.EXPIRED_FEATURE_METADATA), hasItem(XPackPlugin.MACHINE_LEARNING));
}
// Pick a license that does allow machine learning
mode = randomValidLicenseType();
enableLicensing(mode);
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PersistentActionResponse> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), listener);
PersistentActionResponse response = listener.actionGet();
assertNotNull(response);
}
}
public void testMachineLearningStopDatafeedActionNotRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
PlainListenableActionFuture<PutDatafeedAction.Response> putDatafeedListener = new PlainListenableActionFuture<>(
client.threadPool());
new MachineLearningClient(client).putDatafeed(
new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener);
PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet();
assertNotNull(putDatafeedResponse);
PlainListenableActionFuture<PersistentActionResponse> openJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener);
PersistentActionResponse openJobResponse = openJobListener.actionGet();
assertNotNull(openJobResponse);
PlainListenableActionFuture<PersistentActionResponse> startDatafeedListener = new PlainListenableActionFuture<>(
client.threadPool());
new MachineLearningClient(client).startDatafeed(new StartDatafeedAction.Request("foobar", 0L), startDatafeedListener);
PersistentActionResponse startDatafeedResponse = startDatafeedListener.actionGet();
assertNotNull(startDatafeedResponse);
}
// Pick a random license
License.OperationMode mode = randomLicenseType();
enableLicensing(mode);
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<RemovePersistentTaskAction.Response> listener = new PlainListenableActionFuture<>(
client.threadPool());
new MachineLearningClient(client).stopDatafeed(new StopDatafeedAction.Request("foobar"), listener);
listener.actionGet();
}
}
public void testMachineLearningCloseJobActionNotRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
PlainListenableActionFuture<PersistentActionResponse> openJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).openJob(new OpenJobAction.Request("foo"), openJobListener);
PersistentActionResponse openJobResponse = openJobListener.actionGet();
assertNotNull(openJobResponse);
}
// Pick a random license
License.OperationMode mode = randomLicenseType();
enableLicensing(mode);
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<CloseJobAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).closeJob(new CloseJobAction.Request("foo"), listener);
listener.actionGet();
}
}
public void testMachineLearningDeleteJobActionNotRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
}
// Pick a random license
License.OperationMode mode = randomLicenseType();
enableLicensing(mode);
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<DeleteJobAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).deleteJob(new DeleteJobAction.Request("foo"), listener);
listener.actionGet();
}
}
public void testMachineLearningDeleteDatafeedActionNotRestricted() throws Exception {
assertMLAllowed(true);
// test that license restricted apis do now work
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<PutJobAction.Response> putJobListener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).putJob(new PutJobAction.Request(createJob("foo").build(true, "foo")), putJobListener);
PutJobAction.Response putJobResponse = putJobListener.actionGet();
assertNotNull(putJobResponse);
PlainListenableActionFuture<PutDatafeedAction.Response> putDatafeedListener = new PlainListenableActionFuture<>(
client.threadPool());
new MachineLearningClient(client).putDatafeed(
new PutDatafeedAction.Request(createDatafeed("foobar", "foo", Collections.singletonList("foo"))), putDatafeedListener);
PutDatafeedAction.Response putDatafeedResponse = putDatafeedListener.actionGet();
assertNotNull(putDatafeedResponse);
}
// Pick a random license
License.OperationMode mode = randomLicenseType();
enableLicensing(mode);
try (TransportClient client = new TestXPackTransportClient(internalCluster().transportClient().settings())) {
client.addTransportAddress(internalCluster().getDataNodeInstance(Transport.class).boundAddress().publishAddress());
PlainListenableActionFuture<DeleteDatafeedAction.Response> listener = new PlainListenableActionFuture<>(client.threadPool());
new MachineLearningClient(client).deleteDatafeed(new DeleteDatafeedAction.Request("foobar"), listener);
listener.actionGet();
}
}
private static OperationMode randomInvalidLicenseType() {
return randomFrom(License.OperationMode.GOLD, License.OperationMode.STANDARD, License.OperationMode.BASIC);
}
private static OperationMode randomValidLicenseType() {
return randomFrom(License.OperationMode.TRIAL, License.OperationMode.PLATINUM);
}
private static OperationMode randomLicenseType() {
return randomFrom(License.OperationMode.values());
}
private static void assertMLAllowed(boolean expected) {
for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) {
assertEquals(licenseState.isMachineLearningAllowed(), expected);
}
}
public static void disableLicensing() {
disableLicensing(randomValidLicenseType());
}
public static void disableLicensing(License.OperationMode operationMode) {
for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) {
licenseState.update(operationMode, false);
}
}
public static void enableLicensing() {
enableLicensing(randomValidLicenseType());
}
public static void enableLicensing(License.OperationMode operationMode) {
for (XPackLicenseState licenseState : internalCluster().getInstances(XPackLicenseState.class)) {
licenseState.update(operationMode, true);
}
}
}

View File

@ -60,6 +60,13 @@ public abstract class BaseMlIntegTestCase extends SecurityIntegTestCase {
return settings.build();
}
@Override
protected Settings transportClientSettings() {
Settings.Builder settings = Settings.builder().put(super.transportClientSettings());
settings.put(XPackSettings.MACHINE_LEARNING_ENABLED.getKey(), true);
return settings.build();
}
protected Job.Builder createJob(String id) {
DataDescription.Builder dataDescription = new DataDescription.Builder();
dataDescription.setFormat(DataDescription.DataFormat.JSON);