shield: only restore the original if we forcefully replaced it

Original commit: elastic/x-pack-elasticsearch@347a4dba3f
This commit is contained in:
jaymode 2016-01-28 12:50:46 -05:00
parent 19545596cf
commit 1b4bac8203
2 changed files with 12 additions and 8 deletions

View File

@ -13,6 +13,7 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.action.support.ActionFilterChain; import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.component.AbstractComponent; import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -87,7 +88,6 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
throw LicenseUtils.newComplianceException(ShieldPlugin.NAME); throw LicenseUtils.newComplianceException(ShieldPlugin.NAME);
} }
final ThreadContext.StoredContext original = threadContext.newStoredContext();
try { try {
if (licenseState.securityEnabled()) { if (licenseState.securityEnabled()) {
// FIXME yet another hack. Needed to work around something like // FIXME yet another hack. Needed to work around something like
@ -121,6 +121,7 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
at java.lang.Thread.run(Thread.java:745) at java.lang.Thread.run(Thread.java:745)
*/ */
if (INTERNAL_PREDICATE.test(action)) { if (INTERNAL_PREDICATE.test(action)) {
final ThreadContext.StoredContext original = threadContext.newStoredContext();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
String shieldAction = actionMapper.action(action, request); String shieldAction = actionMapper.action(action, request);
User user = authcService.authenticate(shieldAction, request, User.SYSTEM); User user = authcService.authenticate(shieldAction, request, User.SYSTEM);
@ -159,12 +160,11 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
interceptor.intercept(request, user); interceptor.intercept(request, user);
} }
} }
chain.proceed(task, action, request, new SigningListener(this, listener, original)); chain.proceed(task, action, request, new SigningListener(this, listener, null));
} else { } else {
chain.proceed(task, action, request, listener); chain.proceed(task, action, request, listener);
} }
} catch (Throwable t) { } catch (Throwable t) {
original.restore();
listener.onFailure(t); listener.onFailure(t);
} }
} }
@ -232,7 +232,7 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
private final ActionListener innerListener; private final ActionListener innerListener;
private final ThreadContext.StoredContext threadContext; private final ThreadContext.StoredContext threadContext;
private SigningListener(ShieldActionFilter filter, ActionListener innerListener, ThreadContext.StoredContext threadContext) { private SigningListener(ShieldActionFilter filter, ActionListener innerListener, @Nullable ThreadContext.StoredContext threadContext) {
this.filter = filter; this.filter = filter;
this.innerListener = innerListener; this.innerListener = innerListener;
this.threadContext = threadContext; this.threadContext = threadContext;
@ -240,7 +240,9 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
@Override @SuppressWarnings("unchecked") @Override @SuppressWarnings("unchecked")
public void onResponse(Response response) { public void onResponse(Response response) {
threadContext.restore(); if (threadContext != null) {
threadContext.restore();
}
try { try {
response = this.filter.sign(response); response = this.filter.sign(response);
innerListener.onResponse(response); innerListener.onResponse(response);
@ -251,7 +253,9 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
@Override @Override
public void onFailure(Throwable e) { public void onFailure(Throwable e) {
threadContext.restore(); if (threadContext != null) {
threadContext.restore();
}
innerListener.onFailure(e); innerListener.onFailure(e);
} }
} }

View File

@ -78,11 +78,11 @@ public class ShieldServerTransportService extends TransportService {
@Override @Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request, TransportRequestOptions options, TransportResponseHandler<T> handler) { public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request, TransportRequestOptions options, TransportResponseHandler<T> handler) {
final ThreadContext.StoredContext original = threadPool.getThreadContext().newStoredContext();
// FIXME this is really just a hack. What happens is that we send a request and we always copy headers over // FIXME this is really just a hack. What happens is that we send a request and we always copy headers over
// Sometimes a system action gets executed like a internal create index request or update mappings request // Sometimes a system action gets executed like a internal create index request or update mappings request
// which means that the user is copied over to system actions and these really fail for internal things... // which means that the user is copied over to system actions and these really fail for internal things...
if ((clientFilter instanceof ClientTransportFilter.Node) && INTERNAL_PREDICATE.test(action)) { if ((clientFilter instanceof ClientTransportFilter.Node) && INTERNAL_PREDICATE.test(action)) {
final ThreadContext.StoredContext original = threadPool.getThreadContext().newStoredContext();
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) {
try { try {
clientFilter.outbound(action, request); clientFilter.outbound(action, request);
@ -94,7 +94,7 @@ public class ShieldServerTransportService extends TransportService {
} else { } else {
try { try {
clientFilter.outbound(action, request); clientFilter.outbound(action, request);
super.sendRequest(node, action, request, options, new ContextRestoreResponseHandler<>(original, handler)); super.sendRequest(node, action, request, options, handler);
} catch (Throwable t) { } catch (Throwable t) {
handler.handleException(new TransportException("failed sending request", t)); handler.handleException(new TransportException("failed sending request", t));
} }