Plugins: Make rest headers registration pull based

Currently custom headers that should be passed through rest requests are
registered by depending on the RestController in guice and calling a
registration method. This change moves that registration to a getter for
plugins, and makes the RestController take the set of headers on
construction.
This commit is contained in:
Ryan Ernst 2016-07-14 18:45:53 -07:00
parent 05271d58ca
commit 0b514f82a0
10 changed files with 62 additions and 100 deletions

View File

@ -20,10 +20,12 @@
package org.elasticsearch.action; package org.elasticsearch.action;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainAction; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainAction;
import org.elasticsearch.action.admin.cluster.allocation.TransportClusterAllocationExplainAction; import org.elasticsearch.action.admin.cluster.allocation.TransportClusterAllocationExplainAction;
@ -335,7 +337,8 @@ public class ActionModule extends AbstractModule {
actionFilters = setupActionFilters(actionPlugins, ingestEnabled); actionFilters = setupActionFilters(actionPlugins, ingestEnabled);
autoCreateIndex = transportClient ? null : new AutoCreateIndex(settings, resolver); autoCreateIndex = transportClient ? null : new AutoCreateIndex(settings, resolver);
destructiveOperations = new DestructiveOperations(settings, clusterSettings); destructiveOperations = new DestructiveOperations(settings, clusterSettings);
restController = new RestController(settings); Set<String> headers = actionPlugins.stream().flatMap(p -> p.getRestHeaders().stream()).collect(Collectors.toSet());
restController = new RestController(settings, headers);
} }
public Map<String, ActionHandler<?, ?>> getActions() { public Map<String, ActionHandler<?, ?>> getActions() {

View File

@ -28,11 +28,11 @@ import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.rest.RestHandler; import org.elasticsearch.rest.RestHandler;
import java.util.Collection;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import static java.util.Collections.emptyList;
/** /**
* An additional extension point for {@link Plugin}s that extends Elasticsearch's scripting functionality. Implement it like this: * An additional extension point for {@link Plugin}s that extends Elasticsearch's scripting functionality. Implement it like this:
* <pre>{@code * <pre>{@code
@ -50,22 +50,29 @@ public interface ActionPlugin {
* Actions added by this plugin. * Actions added by this plugin.
*/ */
default List<ActionHandler<? extends ActionRequest<?>, ? extends ActionResponse>> getActions() { default List<ActionHandler<? extends ActionRequest<?>, ? extends ActionResponse>> getActions() {
return emptyList(); return Collections.emptyList();
} }
/** /**
* Action filters added by this plugin. * Action filters added by this plugin.
*/ */
default List<Class<? extends ActionFilter>> getActionFilters() { default List<Class<? extends ActionFilter>> getActionFilters() {
return emptyList(); return Collections.emptyList();
} }
/** /**
* Rest handlers added by this plugin. * Rest handlers added by this plugin.
*/ */
default List<Class<? extends RestHandler>> getRestHandlers() { default List<Class<? extends RestHandler>> getRestHandlers() {
return emptyList(); return Collections.emptyList();
} }
public static final class ActionHandler<Request extends ActionRequest<Request>, Response extends ActionResponse> { /**
* Returns headers which should be copied through rest requests on to internal requests.
*/
default Collection<String> getRestHeaders() {
return Collections.emptyList();
}
final class ActionHandler<Request extends ActionRequest<Request>, Response extends ActionResponse> {
private final GenericAction<Request, Response> action; private final GenericAction<Request, Response> action;
private final Class<? extends TransportAction<Request, Response>> transportAction; private final Class<? extends TransportAction<Request, Response>> transportAction;
private final Class<?>[] supportTransportActions; private final Class<?>[] supportTransportActions;

View File

@ -24,14 +24,15 @@ import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.ActionPlugin;
/** /**
* Base handler for REST requests. * Base handler for REST requests.
* <p> * <p>
* This handler makes sure that the headers &amp; context of the handled {@link RestRequest requests} are copied over to * This handler makes sure that the headers &amp; context of the handled {@link RestRequest requests} are copied over to
* the transport requests executed by the associated client. While the context is fully copied over, not all the headers * the transport requests executed by the associated client. While the context is fully copied over, not all the headers
* are copied, but a selected few. It is possible to control what headers are copied over by registering them using * are copied, but a selected few. It is possible to control what headers are copied over by returning them in
* {@link org.elasticsearch.rest.RestController#registerRelevantHeaders(String...)} * {@link ActionPlugin#getRestHeaders()}.
*/ */
public abstract class BaseRestHandler extends AbstractComponent implements RestHandler { public abstract class BaseRestHandler extends AbstractComponent implements RestHandler {
public static final Setting<Boolean> MULTI_ALLOW_EXPLICIT_INDEX = public static final Setting<Boolean> MULTI_ALLOW_EXPLICIT_INDEX =

View File

@ -28,16 +28,17 @@ import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.rest.support.RestUtils; import org.elasticsearch.rest.support.RestUtils;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import static java.util.Collections.emptySet;
import static java.util.Collections.unmodifiableSet; import static java.util.Collections.unmodifiableSet;
import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
import static org.elasticsearch.rest.RestStatus.OK; import static org.elasticsearch.rest.RestStatus.OK;
@ -55,13 +56,15 @@ public class RestController extends AbstractLifecycleComponent {
private final RestHandlerFilter handlerFilter = new RestHandlerFilter(); private final RestHandlerFilter handlerFilter = new RestHandlerFilter();
private Set<String> relevantHeaders = emptySet(); /** Rest headers that are copied to internal requests made during a rest request. */
private final Set<String> headersToCopy;
// non volatile since the assumption is that pre processors are registered on startup // non volatile since the assumption is that pre processors are registered on startup
private RestFilter[] filters = new RestFilter[0]; private RestFilter[] filters = new RestFilter[0];
public RestController(Settings settings) { public RestController(Settings settings, Set<String> headersToCopy) {
super(settings); super(settings);
this.headersToCopy = headersToCopy;
} }
@Override @Override
@ -79,28 +82,6 @@ public class RestController extends AbstractLifecycleComponent {
} }
} }
/**
* Controls which REST headers get copied over from a {@link org.elasticsearch.rest.RestRequest} to
* its corresponding {@link org.elasticsearch.transport.TransportRequest}(s).
*
* By default no headers get copied but it is possible to extend this behaviour via plugins by calling this method.
*/
public synchronized void registerRelevantHeaders(String... headers) {
Set<String> newRelevantHeaders = new HashSet<>(relevantHeaders.size() + headers.length);
newRelevantHeaders.addAll(relevantHeaders);
Collections.addAll(newRelevantHeaders, headers);
relevantHeaders = unmodifiableSet(newRelevantHeaders);
}
/**
* Returns the REST headers that get copied over from a {@link org.elasticsearch.rest.RestRequest} to
* its corresponding {@link org.elasticsearch.transport.TransportRequest}(s).
* By default no headers get copied but it is possible to extend this behaviour via plugins by calling {@link #registerRelevantHeaders(String...)}.
*/
public Set<String> relevantHeaders() {
return relevantHeaders;
}
/** /**
* Registers a pre processor to be executed before the rest request is actually handled. * Registers a pre processor to be executed before the rest request is actually handled.
*/ */
@ -213,7 +194,7 @@ public class RestController extends AbstractLifecycleComponent {
return; return;
} }
try (ThreadContext.StoredContext t = threadContext.stashContext()) { try (ThreadContext.StoredContext t = threadContext.stashContext()) {
for (String key : relevantHeaders) { for (String key : headersToCopy) {
String httpHeader = request.header(key); String httpHeader = request.header(key);
if (httpHeader != null) { if (httpHeader != null) {
threadContext.putHeader(key, httpHeader); threadContext.putHeader(key, httpHeader);

View File

@ -18,6 +18,7 @@
*/ */
package org.elasticsearch.http; package org.elasticsearch.http;
import java.util.Collections;
import java.util.Map; import java.util.Map;
import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreaker;
@ -59,7 +60,7 @@ public class HttpServerTests extends ESTestCase {
inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS); inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
HttpServerTransport httpServerTransport = new TestHttpServerTransport(); HttpServerTransport httpServerTransport = new TestHttpServerTransport();
RestController restController = new RestController(settings); RestController restController = new RestController(settings, Collections.emptySet());
restController.registerHandler(RestRequest.Method.GET, "/", restController.registerHandler(RestRequest.Method.GET, "/",
(request, channel, client) -> channel.sendResponse( (request, channel, client) -> channel.sendResponse(
new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))); new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));

View File

@ -19,6 +19,13 @@
package org.elasticsearch.rest; package org.elasticsearch.rest;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -26,16 +33,6 @@ import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.FakeRestRequest;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq; import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doCallRealMethod;
@ -44,41 +41,10 @@ import static org.mockito.Mockito.verify;
public class RestControllerTests extends ESTestCase { public class RestControllerTests extends ESTestCase {
public void testRegisterRelevantHeaders() throws InterruptedException {
final RestController restController = new RestController(Settings.EMPTY);
int iterations = randomIntBetween(1, 5);
Set<String> headers = new HashSet<>();
ExecutorService executorService = Executors.newFixedThreadPool(iterations);
for (int i = 0; i < iterations; i++) {
int headersCount = randomInt(10);
final Set<String> newHeaders = new HashSet<>();
for (int j = 0; j < headersCount; j++) {
String usefulHeader = randomRealisticUnicodeOfLengthBetween(1, 30);
newHeaders.add(usefulHeader);
}
headers.addAll(newHeaders);
executorService.submit((Runnable) () -> restController.registerRelevantHeaders(newHeaders.toArray(new String[newHeaders.size()])));
}
executorService.shutdown();
assertThat(executorService.awaitTermination(1, TimeUnit.SECONDS), equalTo(true));
String[] relevantHeaders = restController.relevantHeaders().toArray(new String[restController.relevantHeaders().size()]);
assertThat(relevantHeaders.length, equalTo(headers.size()));
Arrays.sort(relevantHeaders);
String[] headersArray = new String[headers.size()];
headersArray = headers.toArray(headersArray);
Arrays.sort(headersArray);
assertThat(relevantHeaders, equalTo(headersArray));
}
public void testApplyRelevantHeaders() throws Exception { public void testApplyRelevantHeaders() throws Exception {
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
final RestController restController = new RestController(Settings.EMPTY) { Set<String> headers = new HashSet<>(Arrays.asList("header.1", "header.2"));
final RestController restController = new RestController(Settings.EMPTY, headers) {
@Override @Override
boolean checkRequestParameters(RestRequest request, RestChannel channel) { boolean checkRequestParameters(RestRequest request, RestChannel channel) {
return true; return true;
@ -89,11 +55,9 @@ public class RestControllerTests extends ESTestCase {
assertEquals("true", threadContext.getHeader("header.1")); assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2")); assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3")); assertNull(threadContext.getHeader("header.3"));
} }
}; };
threadContext.putHeader("header.3", "true"); threadContext.putHeader("header.3", "true");
restController.registerRelevantHeaders("header.1", "header.2");
Map<String, String> restHeaders = new HashMap<>(); Map<String, String> restHeaders = new HashMap<>();
restHeaders.put("header.1", "true"); restHeaders.put("header.1", "true");
restHeaders.put("header.2", "true"); restHeaders.put("header.2", "true");
@ -105,7 +69,7 @@ public class RestControllerTests extends ESTestCase {
} }
public void testCanTripCircuitBreaker() throws Exception { public void testCanTripCircuitBreaker() throws Exception {
RestController controller = new RestController(Settings.EMPTY); RestController controller = new RestController(Settings.EMPTY, Collections.emptySet());
// trip circuit breaker by default // trip circuit breaker by default
controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true)); controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false)); controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));

View File

@ -40,7 +40,7 @@ import static org.hamcrest.CoreMatchers.equalTo;
public class RestFilterChainTests extends ESTestCase { public class RestFilterChainTests extends ESTestCase {
public void testRestFilters() throws Exception { public void testRestFilters() throws Exception {
RestController restController = new RestController(Settings.EMPTY); RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
int numFilters = randomInt(10); int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters); Set<Integer> orders = new HashSet<>(numFilters);
@ -121,7 +121,7 @@ public class RestFilterChainTests extends ESTestCase {
} }
}); });
RestController restController = new RestController(Settings.EMPTY); RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
restController.registerFilter(testFilter); restController.registerFilter(testFilter);
restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() { restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() {

View File

@ -58,6 +58,7 @@ import org.elasticsearch.test.ESTestCase;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import static java.util.Collections.emptyList; import static java.util.Collections.emptyList;
@ -70,7 +71,7 @@ public class RestIndicesActionTests extends ESTestCase {
public void testBuildTable() { public void testBuildTable() {
final Settings settings = Settings.EMPTY; final Settings settings = Settings.EMPTY;
final RestController restController = new RestController(settings); final RestController restController = new RestController(settings, Collections.emptySet());
final RestIndicesAction action = new RestIndicesAction(settings, restController, new IndexNameExpressionResolver(settings)); final RestIndicesAction action = new RestIndicesAction(settings, restController, new IndexNameExpressionResolver(settings));
// build a (semi-)random table // build a (semi-)random table

View File

@ -37,6 +37,7 @@ import org.elasticsearch.snapshots.Snapshot;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
@ -50,7 +51,7 @@ public class RestRecoveryActionTests extends ESTestCase {
public void testRestRecoveryAction() { public void testRestRecoveryAction() {
final Settings settings = Settings.EMPTY; final Settings settings = Settings.EMPTY;
final RestController restController = new RestController(settings); final RestController restController = new RestController(settings, Collections.emptySet());
final RestRecoveryAction action = new RestRecoveryAction(settings, restController, restController); final RestRecoveryAction action = new RestRecoveryAction(settings, restController, restController);
final int totalShards = randomIntBetween(1, 32); final int totalShards = randomIntBetween(1, 32);
final int successfulShards = Math.max(0, totalShards - randomIntBetween(1, 2)); final int successfulShards = Math.max(0, totalShards - randomIntBetween(1, 2));

View File

@ -46,7 +46,6 @@ import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.indices.TermsLookup; import org.elasticsearch.indices.TermsLookup;
import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope; import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After; import org.junit.After;
@ -75,7 +74,7 @@ import static org.hamcrest.Matchers.is;
@ClusterScope(scope = SUITE) @ClusterScope(scope = SUITE)
public class ContextAndHeaderTransportIT extends HttpSmokeTestCase { public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
private static final List<RequestAndHeaders> requests = new CopyOnWriteArrayList<>(); private static final List<RequestAndHeaders> requests = new CopyOnWriteArrayList<>();
private String randomHeaderKey = randomAsciiOfLength(10); private static final String CUSTOM_HEADER = "SomeCustomHeader";
private String randomHeaderValue = randomAsciiOfLength(20); private String randomHeaderValue = randomAsciiOfLength(20);
private String queryIndex = "query-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT); private String queryIndex = "query-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT);
private String lookupIndex = "lookup-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT); private String lookupIndex = "lookup-" + randomAsciiOfLength(10).toLowerCase(Locale.ROOT);
@ -97,6 +96,7 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
protected Collection<Class<? extends Plugin>> nodePlugins() { protected Collection<Class<? extends Plugin>> nodePlugins() {
ArrayList<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins()); ArrayList<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins());
plugins.add(ActionLoggingPlugin.class); plugins.add(ActionLoggingPlugin.class);
plugins.add(CustomHeadersPlugin.class);
return plugins; return plugins;
} }
@ -219,21 +219,18 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
} }
public void testThatRelevantHttpHeadersBecomeRequestHeaders() throws Exception { public void testThatRelevantHttpHeadersBecomeRequestHeaders() throws Exception {
String relevantHeaderName = "relevant_" + randomHeaderKey; final String IRRELEVANT_HEADER = "SomeIrrelevantHeader";
for (RestController restController : internalCluster().getInstances(RestController.class)) {
restController.registerRelevantHeaders(relevantHeaderName);
}
try (Response response = getRestClient().performRequest( try (Response response = getRestClient().performRequest(
"GET", "/" + queryIndex + "/_search", "GET", "/" + queryIndex + "/_search",
new BasicHeader(randomHeaderKey, randomHeaderValue), new BasicHeader(relevantHeaderName, randomHeaderValue))) { new BasicHeader(CUSTOM_HEADER, randomHeaderValue), new BasicHeader(IRRELEVANT_HEADER, randomHeaderValue))) {
assertThat(response.getStatusLine().getStatusCode(), equalTo(200)); assertThat(response.getStatusLine().getStatusCode(), equalTo(200));
List<RequestAndHeaders> searchRequests = getRequests(SearchRequest.class); List<RequestAndHeaders> searchRequests = getRequests(SearchRequest.class);
assertThat(searchRequests, hasSize(greaterThan(0))); assertThat(searchRequests, hasSize(greaterThan(0)));
for (RequestAndHeaders requestAndHeaders : searchRequests) { for (RequestAndHeaders requestAndHeaders : searchRequests) {
assertThat(requestAndHeaders.headers.containsKey(relevantHeaderName), is(true)); assertThat(requestAndHeaders.headers.containsKey(CUSTOM_HEADER), is(true));
// was not specified, thus is not included // was not specified, thus is not included
assertThat(requestAndHeaders.headers.containsKey(randomHeaderKey), is(false)); assertThat(requestAndHeaders.headers.containsKey(IRRELEVANT_HEADER), is(false));
} }
} }
} }
@ -273,21 +270,21 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
} }
private void assertRequestContainsHeader(ActionRequest request, Map<String, String> context) { private void assertRequestContainsHeader(ActionRequest request, Map<String, String> context) {
String msg = String.format(Locale.ROOT, "Expected header %s to be in request %s", randomHeaderKey, request.getClass().getName()); String msg = String.format(Locale.ROOT, "Expected header %s to be in request %s", CUSTOM_HEADER, request.getClass().getName());
if (request instanceof IndexRequest) { if (request instanceof IndexRequest) {
IndexRequest indexRequest = (IndexRequest) request; IndexRequest indexRequest = (IndexRequest) request;
msg = String.format(Locale.ROOT, "Expected header %s to be in index request %s/%s/%s", randomHeaderKey, msg = String.format(Locale.ROOT, "Expected header %s to be in index request %s/%s/%s", CUSTOM_HEADER,
indexRequest.index(), indexRequest.type(), indexRequest.id()); indexRequest.index(), indexRequest.type(), indexRequest.id());
} }
assertThat(msg, context.containsKey(randomHeaderKey), is(true)); assertThat(msg, context.containsKey(CUSTOM_HEADER), is(true));
assertThat(context.get(randomHeaderKey).toString(), is(randomHeaderValue)); assertThat(context.get(CUSTOM_HEADER).toString(), is(randomHeaderValue));
} }
/** /**
* a transport client that adds our random header * a transport client that adds our random header
*/ */
private Client transportClient() { private Client transportClient() {
return internalCluster().transportClient().filterWithHeader(Collections.singletonMap(randomHeaderKey, randomHeaderValue)); return internalCluster().transportClient().filterWithHeader(Collections.singletonMap(CUSTOM_HEADER, randomHeaderValue));
} }
public static class ActionLoggingPlugin extends Plugin implements ActionPlugin { public static class ActionLoggingPlugin extends Plugin implements ActionPlugin {
@ -347,4 +344,10 @@ public class ContextAndHeaderTransportIT extends HttpSmokeTestCase {
this.request = request; this.request = request;
} }
} }
public static class CustomHeadersPlugin extends Plugin implements ActionPlugin {
public Collection<String> getRestHeaders() {
return Collections.singleton(CUSTOM_HEADER);
}
}
} }