Add support for the task management framework

These changes are required to support elastic/elasticsearchelastic/elasticsearch#15347

Original commit: elastic/x-pack-elasticsearch@37adf4fc83
This commit is contained in:
Igor Motov 2016-01-05 10:07:54 -05:00
parent ebcca4f3c2
commit 852aac0b9c
9 changed files with 45 additions and 23 deletions

View File

@ -26,6 +26,7 @@ import org.elasticsearch.shield.authz.AuthorizationService;
import org.elasticsearch.shield.authz.Privilege; import org.elasticsearch.shield.authz.Privilege;
import org.elasticsearch.shield.crypto.CryptoService; import org.elasticsearch.shield.crypto.CryptoService;
import org.elasticsearch.shield.license.ShieldLicenseState; import org.elasticsearch.shield.license.ShieldLicenseState;
import org.elasticsearch.tasks.Task;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -64,7 +65,7 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
} }
@Override @Override
public void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) { public void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
/** /**
A functional requirement - when the license of shield is disabled (invalid/expires), shield will continue A functional requirement - when the license of shield is disabled (invalid/expires), shield will continue
@ -100,9 +101,9 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
interceptor.intercept(request, user); interceptor.intercept(request, user);
} }
} }
chain.proceed(action, request, new SigningListener(this, listener)); chain.proceed(task, action, request, new SigningListener(this, listener));
} else { } else {
chain.proceed(action, request, listener); chain.proceed(task, action, request, listener);
} }
} catch (Throwable t) { } catch (Throwable t) {
listener.onFailure(t); listener.onFailure(t);

View File

@ -12,6 +12,7 @@ import org.elasticsearch.shield.action.ShieldActionMapper;
import org.elasticsearch.shield.authc.AuthenticationService; import org.elasticsearch.shield.authc.AuthenticationService;
import org.elasticsearch.shield.authc.pki.PkiRealm; import org.elasticsearch.shield.authc.pki.PkiRealm;
import org.elasticsearch.shield.authz.AuthorizationService; import org.elasticsearch.shield.authz.AuthorizationService;
import org.elasticsearch.transport.DelegatingTransportChannel;
import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.netty.NettyTransportChannel; import org.elasticsearch.transport.netty.NettyTransportChannel;
@ -71,8 +72,13 @@ public interface ServerTransportFilter {
*/ */
String shieldAction = actionMapper.action(action, request); String shieldAction = actionMapper.action(action, request);
if (extractClientCert && (transportChannel instanceof NettyTransportChannel)) { TransportChannel unwrappedChannel = transportChannel;
Channel channel = ((NettyTransportChannel)transportChannel).getChannel(); while (unwrappedChannel instanceof DelegatingTransportChannel) {
unwrappedChannel = ((DelegatingTransportChannel) unwrappedChannel).getChannel();
}
if (extractClientCert && (unwrappedChannel instanceof NettyTransportChannel)) {
Channel channel = ((NettyTransportChannel)unwrappedChannel).getChannel();
SslHandler sslHandler = channel.getPipeline().get(SslHandler.class); SslHandler sslHandler = channel.getPipeline().get(SslHandler.class);
assert sslHandler != null; assert sslHandler != null;

View File

@ -14,6 +14,7 @@ import org.elasticsearch.shield.authz.AuthorizationService;
import org.elasticsearch.shield.authz.accesscontrol.RequestContext; import org.elasticsearch.shield.authz.accesscontrol.RequestContext;
import org.elasticsearch.shield.license.ShieldLicenseState; import org.elasticsearch.shield.license.ShieldLicenseState;
import org.elasticsearch.shield.transport.netty.ShieldNettyTransport; import org.elasticsearch.shield.transport.netty.ShieldNettyTransport;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannel;
@ -144,7 +145,7 @@ public class ShieldServerTransportService extends TransportService {
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void messageReceived(T request, TransportChannel channel) throws Exception { public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
try { try {
if (licenseState.securityEnabled()) { if (licenseState.securityEnabled()) {
String profile = channel.getProfileName(); String profile = channel.getProfileName();
@ -163,13 +164,18 @@ public class ShieldServerTransportService extends TransportService {
} }
RequestContext context = new RequestContext(request); RequestContext context = new RequestContext(request);
RequestContext.setCurrent(context); RequestContext.setCurrent(context);
handler.messageReceived(request, channel); handler.messageReceived(request, channel, task);
} catch (Throwable t) { } catch (Throwable t) {
channel.sendResponse(t); channel.sendResponse(t);
} finally { } finally {
RequestContext.removeCurrent(); RequestContext.removeCurrent();
} }
} }
@Override
public void messageReceived(T request, TransportChannel channel) throws Exception {
throw new UnsupportedOperationException("task parameter is required for this operation");
}
} }
} }

View File

@ -18,6 +18,7 @@ import org.elasticsearch.shield.authc.AuthenticationService;
import org.elasticsearch.shield.authz.AuthorizationService; import org.elasticsearch.shield.authz.AuthorizationService;
import org.elasticsearch.shield.crypto.CryptoService; import org.elasticsearch.shield.crypto.CryptoService;
import org.elasticsearch.shield.license.ShieldLicenseState; import org.elasticsearch.shield.license.ShieldLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.Before; import org.junit.Before;
@ -62,12 +63,13 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionRequest request = mock(ActionRequest.class); ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class); ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class); ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
User user = new User.Simple("username", new String[] { "r1", "r2" }); User user = new User.Simple("username", new String[] { "r1", "r2" });
when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user); when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user);
doReturn(request).when(spy(filter)).unsign(user, "_action", request); doReturn(request).when(spy(filter)).unsign(user, "_action", request);
filter.apply("_action", request, listener, chain); filter.apply(task, "_action", request, listener, chain);
verify(authzService).authorize(user, "_action", request); verify(authzService).authorize(user, "_action", request);
verify(chain).proceed(eq("_action"), eq(request), isA(ShieldActionFilter.SigningListener.class)); verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ShieldActionFilter.SigningListener.class));
} }
public void testActionProcessException() throws Exception { public void testActionProcessException() throws Exception {
@ -75,10 +77,11 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionListener listener = mock(ActionListener.class); ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class); ActionFilterChain chain = mock(ActionFilterChain.class);
RuntimeException exception = new RuntimeException("process-error"); RuntimeException exception = new RuntimeException("process-error");
Task task = mock(Task.class);
User user = new User.Simple("username", new String[] { "r1", "r2" }); User user = new User.Simple("username", new String[] { "r1", "r2" });
when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user); when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user);
doThrow(exception).when(authzService).authorize(user, "_action", request); doThrow(exception).when(authzService).authorize(user, "_action", request);
filter.apply("_action", request, listener, chain); filter.apply(task, "_action", request, listener, chain);
verify(listener).onFailure(exception); verify(listener).onFailure(exception);
verifyNoMoreInteractions(chain); verifyNoMoreInteractions(chain);
} }
@ -88,13 +91,14 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionListener listener = mock(ActionListener.class); ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class); ActionFilterChain chain = mock(ActionFilterChain.class);
User user = mock(User.class); User user = mock(User.class);
Task task = mock(Task.class);
when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user); when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user);
when(cryptoService.signed("signed_scroll_id")).thenReturn(true); when(cryptoService.signed("signed_scroll_id")).thenReturn(true);
when(cryptoService.unsignAndVerify("signed_scroll_id")).thenReturn("scroll_id"); when(cryptoService.unsignAndVerify("signed_scroll_id")).thenReturn("scroll_id");
filter.apply("_action", request, listener, chain); filter.apply(task, "_action", request, listener, chain);
assertThat(request.scrollId(), equalTo("scroll_id")); assertThat(request.scrollId(), equalTo("scroll_id"));
verify(authzService).authorize(user, "_action", request); verify(authzService).authorize(user, "_action", request);
verify(chain).proceed(eq("_action"), eq(request), isA(ShieldActionFilter.SigningListener.class)); verify(chain).proceed(eq(task), eq("_action"), eq(request), isA(ShieldActionFilter.SigningListener.class));
} }
public void testActionSignatureError() throws Exception { public void testActionSignatureError() throws Exception {
@ -103,10 +107,11 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionFilterChain chain = mock(ActionFilterChain.class); ActionFilterChain chain = mock(ActionFilterChain.class);
IllegalArgumentException sigException = new IllegalArgumentException("bad bad boy"); IllegalArgumentException sigException = new IllegalArgumentException("bad bad boy");
User user = mock(User.class); User user = mock(User.class);
Task task = mock(Task.class);
when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user); when(authcService.authenticate("_action", request, User.SYSTEM)).thenReturn(user);
when(cryptoService.signed("scroll_id")).thenReturn(true); when(cryptoService.signed("scroll_id")).thenReturn(true);
doThrow(sigException).when(cryptoService).unsignAndVerify("scroll_id"); doThrow(sigException).when(cryptoService).unsignAndVerify("scroll_id");
filter.apply("_action", request, listener, chain); filter.apply(task, "_action", request, listener, chain);
verify(listener).onFailure(isA(ElasticsearchSecurityException.class)); verify(listener).onFailure(isA(ElasticsearchSecurityException.class));
verify(auditTrail).tamperedRequest(user, "_action", request); verify(auditTrail).tamperedRequest(user, "_action", request);
verifyNoMoreInteractions(chain); verifyNoMoreInteractions(chain);
@ -116,11 +121,12 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionRequest request = mock(ActionRequest.class); ActionRequest request = mock(ActionRequest.class);
ActionListener listener = mock(ActionListener.class); ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class); ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
when(shieldLicenseState.securityEnabled()).thenReturn(false); when(shieldLicenseState.securityEnabled()).thenReturn(false);
filter.apply("_action", request, listener, chain); filter.apply(task, "_action", request, listener, chain);
verifyZeroInteractions(authcService); verifyZeroInteractions(authcService);
verifyZeroInteractions(authzService); verifyZeroInteractions(authzService);
verify(chain).proceed(eq("_action"), eq(request), eq(listener)); verify(chain).proceed(eq(task), eq("_action"), eq(request), eq(listener));
} }
} }

View File

@ -31,7 +31,6 @@ import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler; import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.netty.NettyTransport; import org.elasticsearch.transport.netty.NettyTransport;
import org.elasticsearch.transport.netty.NettyTransportChannel;
import org.mockito.InOrder; import org.mockito.InOrder;
import java.io.IOException; import java.io.IOException;
@ -102,11 +101,11 @@ public class TransportFilterTests extends ESIntegTestCase {
ClientTransportFilter sourceClientFilter = internalCluster().getInstance(ClientTransportFilter.class, source); ClientTransportFilter sourceClientFilter = internalCluster().getInstance(ClientTransportFilter.class, source);
ClientTransportFilter targetClientFilter = internalCluster().getInstance(ClientTransportFilter.class, target); ClientTransportFilter targetClientFilter = internalCluster().getInstance(ClientTransportFilter.class, target);
InOrder inOrder = inOrder(sourceServerFilter, sourceClientFilter, targetServerFilter, targetClientFilter); InOrder inOrder = inOrder(sourceClientFilter, targetServerFilter, targetClientFilter, sourceServerFilter);
inOrder.verify(sourceClientFilter).outbound("_action", new Request("src_to_trgt")); inOrder.verify(sourceClientFilter).outbound("_action", new Request("src_to_trgt"));
inOrder.verify(targetServerFilter).inbound(eq("_action"), eq(new Request("src_to_trgt")), isA(NettyTransportChannel.class)); inOrder.verify(targetServerFilter).inbound(eq("_action"), eq(new Request("src_to_trgt")), isA(TransportChannel.class));
inOrder.verify(targetClientFilter).outbound("_action", new Request("trgt_to_src")); inOrder.verify(targetClientFilter).outbound("_action", new Request("trgt_to_src"));
inOrder.verify(sourceServerFilter).inbound(eq("_action"), eq(new Request("trgt_to_src")), isA(NettyTransportChannel.class)); inOrder.verify(sourceServerFilter).inbound(eq("_action"), eq(new Request("trgt_to_src")), isA(TransportChannel.class));
} }
public static class InternalPlugin extends Plugin { public static class InternalPlugin extends Plugin {

View File

@ -13,6 +13,7 @@ import org.elasticsearch.test.ShieldIntegTestCase;
import java.util.Map; import java.util.Map;
import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
// this class sits in org.elasticsearch.transport so that TransportService.requestHandlers is visible // this class sits in org.elasticsearch.transport so that TransportService.requestHandlers is visible
public class ShieldServerTransportServiceTests extends ShieldIntegTestCase { public class ShieldServerTransportServiceTests extends ShieldIntegTestCase {
@ -30,8 +31,8 @@ public class ShieldServerTransportServiceTests extends ShieldIntegTestCase {
for (Map.Entry<String, RequestHandlerRegistry> entry : transportService.requestHandlers.entrySet()) { for (Map.Entry<String, RequestHandlerRegistry> entry : transportService.requestHandlers.entrySet()) {
assertThat( assertThat(
"handler not wrapped by " + ShieldServerTransportService.ProfileSecuredRequestHandler.class + "; do all the handler registration methods have overrides?", "handler not wrapped by " + ShieldServerTransportService.ProfileSecuredRequestHandler.class + "; do all the handler registration methods have overrides?",
entry.getValue().getHandler(), entry.getValue().toString(),
instanceOf(ShieldServerTransportService.ProfileSecuredRequestHandler.class) startsWith(ShieldServerTransportService.ProfileSecuredRequestHandler.class.getName() + "@")
); );
} }
} }

View File

@ -17,6 +17,7 @@ cluster:monitor/nodes/stats
cluster:monitor/state cluster:monitor/state
cluster:monitor/stats cluster:monitor/stats
cluster:monitor/task cluster:monitor/task
cluster:monitor/tasks/lists
indices:admin/aliases indices:admin/aliases
indices:admin/aliases/exists indices:admin/aliases/exists
indices:admin/aliases/get indices:admin/aliases/get

View File

@ -6,6 +6,7 @@ cluster:monitor/nodes/info[n]
cluster:monitor/nodes/liveness cluster:monitor/nodes/liveness
cluster:monitor/nodes/stats[n] cluster:monitor/nodes/stats[n]
cluster:monitor/stats[n] cluster:monitor/stats[n]
cluster:monitor/tasks/lists[n]
cluster:admin/shield/realm/cache/clear cluster:admin/shield/realm/cache/clear
cluster:admin/shield/realm/cache/clear[n] cluster:admin/shield/realm/cache/clear[n]
indices:admin/analyze[s] indices:admin/analyze[s]

View File

@ -14,6 +14,7 @@ import org.elasticsearch.cluster.ClusterService;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.plugin.core.LicenseUtils; import org.elasticsearch.license.plugin.core.LicenseUtils;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
import org.elasticsearch.watcher.license.WatcherLicensee; import org.elasticsearch.watcher.license.WatcherLicensee;
@ -35,9 +36,9 @@ public abstract class WatcherTransportAction<Request extends MasterNodeRequest<R
} }
@Override @Override
protected void doExecute(Request request, ActionListener<Response> listener) { protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
if (watcherLicensee.isWatcherTransportActionAllowed()) { if (watcherLicensee.isWatcherTransportActionAllowed()) {
super.doExecute(request, listener); super.doExecute(task, request, listener);
} else { } else {
listener.onFailure(LicenseUtils.newComplianceException(WatcherLicensee.ID)); listener.onFailure(LicenseUtils.newComplianceException(WatcherLicensee.ID));
} }