diff --git a/src/main/java/org/elasticsearch/common/ContextHolder.java b/src/main/java/org/elasticsearch/common/ContextHolder.java new file mode 100644 index 00000000000..6841b72525f --- /dev/null +++ b/src/main/java/org/elasticsearch/common/ContextHolder.java @@ -0,0 +1,129 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.common; + +import com.carrotsearch.hppc.ObjectObjectAssociativeContainer; +import com.carrotsearch.hppc.ObjectObjectOpenHashMap; +import org.elasticsearch.common.collect.ImmutableOpenMap; + +/** + * + */ +public class ContextHolder { + + private ObjectObjectOpenHashMap context; + + /** + * Attaches the given value to the context. + * + * @return The previous value that was associated with the given key in the context, or + * {@code null} if there was none. + */ + @SuppressWarnings("unchecked") + public final synchronized V putInContext(Object key, Object value) { + if (context == null) { + context = new ObjectObjectOpenHashMap<>(2); + } + return (V) context.put(key, value); + } + + /** + * Attaches the given values to the context + */ + public final synchronized void putAllInContext(ObjectObjectAssociativeContainer map) { + if (map == null) { + return; + } + if (context == null) { + context = new ObjectObjectOpenHashMap<>(map); + } else { + context.putAll(map); + } + } + + /** + * @return The context value that is associated with the given key + * + * @see #putInContext(Object, Object) + */ + @SuppressWarnings("unchecked") + public final synchronized V getFromContext(Object key) { + return context != null ? (V) context.get(key) : null; + } + + /** + * @param defaultValue The default value that should be returned for the given key, if no + * value is currently associated with it. + * + * @return The value that is associated with the given key in the context + * + * @see #putInContext(Object, Object) + */ + @SuppressWarnings("unchecked") + public final synchronized V getFromContext(Object key, V defaultValue) { + V value = getFromContext(key); + return value == null ? defaultValue : value; + } + + /** + * Checks if the context contains an entry with the given key + */ + public final synchronized boolean hasInContext(Object key) { + return context != null && context.containsKey(key); + } + + /** + * @return The number of values attached in the context. + */ + public final synchronized int contextSize() { + return context != null ? context.size() : 0; + } + + /** + * Checks if the context is empty. + */ + public final synchronized boolean isContextEmpty() { + return context == null || context.isEmpty(); + } + + /** + * @return A safe immutable copy of the current context. + */ + public synchronized ImmutableOpenMap getContext() { + return context != null ? ImmutableOpenMap.copyOf(context) : ImmutableOpenMap.of(); + } + + /** + * Copies the context from the given context holder to this context holder. Any shared keys between + * the two context will be overridden by the given context holder. + */ + public synchronized void copyContextFrom(ContextHolder other) { + synchronized (other) { + if (other.context == null) { + return; + } + if (context == null) { + context = new ObjectObjectOpenHashMap<>(other.context); + } else { + context.putAll(other.context); + } + } + } +} diff --git a/src/main/java/org/elasticsearch/rest/BaseRestHandler.java b/src/main/java/org/elasticsearch/rest/BaseRestHandler.java index 9f8fbbb3baa..06f539d9354 100644 --- a/src/main/java/org/elasticsearch/rest/BaseRestHandler.java +++ b/src/main/java/org/elasticsearch/rest/BaseRestHandler.java @@ -32,7 +32,12 @@ import java.util.Collections; import java.util.Set; /** - * Base handler for REST requests + * Base handler for REST requests. + * + * This handler makes sure that the headers & context of the handled {@link RestRequest requests} are copied over to + * the transport requests executed by the associated client. While the context is fully copied over, not all the headers + * are copied, but a selected few. It is possible to control what header are copied over by registering them using + * {@link #addUsefulHeaders(String...)} */ public abstract class BaseRestHandler extends AbstractComponent implements RestHandler { @@ -61,44 +66,45 @@ public abstract class BaseRestHandler extends AbstractComponent implements RestH @Override public final void handleRequest(RestRequest request, RestChannel channel) throws Exception { - handleRequest(request, channel, usefulHeaders.size() == 0 ? client : new HeadersCopyClient(client, request, usefulHeaders)); + handleRequest(request, channel, usefulHeaders.size() == 0 ? client : new HeadersAndContextCopyClient(client, request, usefulHeaders)); } protected abstract void handleRequest(RestRequest request, RestChannel channel, Client client) throws Exception; - static final class HeadersCopyClient extends FilterClient { + static final class HeadersAndContextCopyClient extends FilterClient { private final RestRequest restRequest; - private final Set usefulHeaders; private final IndicesAdmin indicesAdmin; private final ClusterAdmin clusterAdmin; + private final Set headers; - HeadersCopyClient(Client in, RestRequest restRequest, Set usefulHeaders) { + HeadersAndContextCopyClient(Client in, RestRequest restRequest, Set headers) { super(in); this.restRequest = restRequest; - this.usefulHeaders = usefulHeaders; - this.indicesAdmin = new IndicesAdmin(in.admin().indices()); - this.clusterAdmin = new ClusterAdmin(in.admin().cluster()); + this.indicesAdmin = new IndicesAdmin(in.admin().indices(), restRequest, headers); + this.clusterAdmin = new ClusterAdmin(in.admin().cluster(), restRequest, headers); + this.headers = headers; } - private void copyHeaders(ActionRequest request) { - for (String usefulHeader : usefulHeaders) { + private static void copyHeadersAndContext(ActionRequest actionRequest, RestRequest restRequest, Set headers) { + for (String usefulHeader : headers) { String headerValue = restRequest.header(usefulHeader); if (headerValue != null) { - request.putHeader(usefulHeader, headerValue); + actionRequest.putHeader(usefulHeader, headerValue); } } + actionRequest.copyContextFrom(restRequest); } @Override public > ActionFuture execute(Action action, Request request) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); return super.execute(action, request); } @Override public > void execute(Action action, Request request, ActionListener listener) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); super.execute(action, request, listener); } @@ -112,38 +118,50 @@ public abstract class BaseRestHandler extends AbstractComponent implements RestH return indicesAdmin; } - private final class ClusterAdmin extends FilterClient.ClusterAdmin { - private ClusterAdmin(ClusterAdminClient in) { + private static final class ClusterAdmin extends FilterClient.ClusterAdmin { + + private final RestRequest restRequest; + private final Set headers; + + private ClusterAdmin(ClusterAdminClient in, RestRequest restRequest, Set headers) { super(in); + this.restRequest = restRequest; + this.headers = headers; } @Override public > ActionFuture execute(Action action, Request request) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); return super.execute(action, request); } @Override public > void execute(Action action, Request request, ActionListener listener) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); super.execute(action, request, listener); } } private final class IndicesAdmin extends FilterClient.IndicesAdmin { - private IndicesAdmin(IndicesAdminClient in) { + + private final RestRequest restRequest; + private final Set headers; + + private IndicesAdmin(IndicesAdminClient in, RestRequest restRequest, Set headers) { super(in); + this.restRequest = restRequest; + this.headers = headers; } @Override public > ActionFuture execute(Action action, Request request) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); return super.execute(action, request); } @Override public > void execute(Action action, Request request, ActionListener listener) { - copyHeaders(request); + copyHeadersAndContext(request, restRequest, headers); super.execute(action, request, listener); } } diff --git a/src/main/java/org/elasticsearch/rest/RestRequest.java b/src/main/java/org/elasticsearch/rest/RestRequest.java index d45013ee947..7f94a6ca045 100644 --- a/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -21,6 +21,7 @@ package org.elasticsearch.rest; import org.elasticsearch.ElasticsearchIllegalArgumentException; import org.elasticsearch.common.Booleans; +import org.elasticsearch.common.ContextHolder; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; @@ -38,7 +39,7 @@ import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; /** * */ -public abstract class RestRequest implements ToXContent.Params { +public abstract class RestRequest extends ContextHolder implements ToXContent.Params { public enum Method { GET, POST, PUT, DELETE, OPTIONS, HEAD diff --git a/src/main/java/org/elasticsearch/transport/TransportMessage.java b/src/main/java/org/elasticsearch/transport/TransportMessage.java index 24dc21ce1d3..5a82bd10e35 100644 --- a/src/main/java/org/elasticsearch/transport/TransportMessage.java +++ b/src/main/java/org/elasticsearch/transport/TransportMessage.java @@ -20,7 +20,7 @@ package org.elasticsearch.transport; import com.carrotsearch.hppc.ObjectObjectOpenHashMap; -import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.ContextHolder; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; @@ -33,12 +33,10 @@ import java.util.Map; import java.util.Set; /** - * + * The transport message is also a {@link ContextHolder context holder} that holds transient context, that is, + * the context is not serialized with message. */ -public abstract class TransportMessage> implements Streamable { - - // a transient (not serialized with the request) key/value registry - private ObjectObjectOpenHashMap context; +public abstract class TransportMessage> extends ContextHolder implements Streamable { private Map headers; @@ -54,9 +52,7 @@ public abstract class TransportMessage> implemen if (((TransportMessage) message).headers != null) { this.headers = new HashMap<>(((TransportMessage) message).headers); } - if (((TransportMessage) message).context != null) { - this.context = new ObjectObjectOpenHashMap<>(((TransportMessage) message).context); - } + copyContextFrom(message); } public void remoteAddress(TransportAddress remoteAddress) { @@ -89,76 +85,6 @@ public abstract class TransportMessage> implemen return headers != null ? headers.keySet() : Collections.emptySet(); } - /** - * Attaches the given transient value to the request - this value will not be serialized - * along with the request. - * - * There are many use cases such data is required, for example, when processing the - * request headers and building other constructs from them, one could "cache" the - * already built construct to avoid reprocessing the header over and over again. - * - * @return The previous value that was associated with the given key in the context, or - * {@code null} if there was none. - */ - @SuppressWarnings("unchecked") - public final synchronized V putInContext(Object key, Object value) { - if (context == null) { - context = new ObjectObjectOpenHashMap<>(2); - } - return (V) context.put(key, value); - } - - /** - * @return The transient value that is associated with the given key in the request context - * @see #putInContext(Object, Object) - */ - @SuppressWarnings("unchecked") - public final synchronized V getFromContext(Object key) { - return context != null ? (V) context.get(key) : null; - } - - /** - * @param defaultValue The default value that should be returned for the given key, if no - * value is currently associated with it. - * - * @return The transient value that is associated with the given key in the request context - * - * @see #putInContext(Object, Object) - */ - @SuppressWarnings("unchecked") - public final synchronized V getFromContext(Object key, V defaultValue) { - V value = getFromContext(key); - return value == null ? defaultValue : value; - } - - /** - * Checks if the request context contains an entry with the given key - */ - public final synchronized boolean hasInContext(Object key) { - return context != null && context.containsKey(key); - } - - /** - * @return The number of transient values attached in the request context. - */ - public final synchronized int contextSize() { - return context != null ? context.size() : 0; - } - - /** - * Checks if the request context is empty. - */ - public final synchronized boolean isContextEmpty() { - return context == null || context.isEmpty(); - } - - /** - * @return A safe immutable copy of the current context of this request. - */ - public synchronized ImmutableOpenMap getContext() { - return context != null ? ImmutableOpenMap.copyOf(context) : ImmutableOpenMap.of(); - } - @Override public void readFrom(StreamInput in) throws IOException { headers = in.readBoolean() ? in.readMap() : null; diff --git a/src/test/java/org/elasticsearch/rest/FakeRestRequest.java b/src/test/java/org/elasticsearch/rest/FakeRestRequest.java index e9f6dafe580..60612ac3770 100644 --- a/src/test/java/org/elasticsearch/rest/FakeRestRequest.java +++ b/src/test/java/org/elasticsearch/rest/FakeRestRequest.java @@ -29,11 +29,14 @@ class FakeRestRequest extends RestRequest { private final Map headers; FakeRestRequest() { - this(new HashMap()); + this(new HashMap(), new HashMap()); } - FakeRestRequest(Map headers) { + FakeRestRequest(Map headers, Map context) { this.headers = headers; + for (Map.Entry entry : context.entrySet()) { + putInContext(entry.getKey(), entry.getValue()); + } } @Override diff --git a/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java b/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java index b3ec5a02408..49277fbc3bf 100644 --- a/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java +++ b/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.rest; +import com.google.common.collect.Maps; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.*; import org.elasticsearch.action.admin.cluster.health.ClusterHealthRequest; @@ -34,6 +35,7 @@ import org.elasticsearch.client.*; import org.elasticsearch.client.support.AbstractClient; import org.elasticsearch.client.support.AbstractClusterAdminClient; import org.elasticsearch.client.support.AbstractIndicesAdminClient; +import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ElasticsearchTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -44,9 +46,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.rest.BaseRestHandler.HeadersCopyClient; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.Matchers.*; public class HeadersCopyClientTests extends ElasticsearchTestCase { @@ -89,116 +91,158 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { @Test public void testCopyHeadersRequest() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); - HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + Map expectedHeaders = new HashMap<>(); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, restContext), usefulRestHeaders); SearchRequest searchRequest = Requests.searchRequest(); - putHeaders(searchRequest, existingTransportHeaders); - assertHeaders(searchRequest, existingTransportHeaders); + putHeaders(searchRequest, transportHeaders); + putContext(searchRequest, transportContext); + assertHeaders(searchRequest, transportHeaders); client.search(searchRequest); assertHeaders(searchRequest, expectedHeaders); + assertContext(searchRequest, expectedContext); GetRequest getRequest = Requests.getRequest("index"); - putHeaders(getRequest, existingTransportHeaders); - assertHeaders(getRequest, existingTransportHeaders); + putHeaders(getRequest, transportHeaders); + putContext(getRequest, transportContext); + assertHeaders(getRequest, transportHeaders); client.get(getRequest); assertHeaders(getRequest, expectedHeaders); + assertContext(getRequest, expectedContext); IndexRequest indexRequest = Requests.indexRequest(); - putHeaders(indexRequest, existingTransportHeaders); - assertHeaders(indexRequest, existingTransportHeaders); + putHeaders(indexRequest, transportHeaders); + putContext(indexRequest, transportContext); + assertHeaders(indexRequest, transportHeaders); client.index(indexRequest); assertHeaders(indexRequest, expectedHeaders); + assertContext(indexRequest, expectedContext); } @Test public void testCopyHeadersClusterAdminRequest() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, expectedContext), usefulRestHeaders); ClusterHealthRequest clusterHealthRequest = Requests.clusterHealthRequest(); - putHeaders(clusterHealthRequest, existingTransportHeaders); - assertHeaders(clusterHealthRequest, existingTransportHeaders); + putHeaders(clusterHealthRequest, transportHeaders); + putContext(clusterHealthRequest, transportContext); + assertHeaders(clusterHealthRequest, transportHeaders); client.admin().cluster().health(clusterHealthRequest); assertHeaders(clusterHealthRequest, expectedHeaders); + assertContext(clusterHealthRequest, expectedContext); ClusterStateRequest clusterStateRequest = Requests.clusterStateRequest(); - putHeaders(clusterStateRequest, existingTransportHeaders); - assertHeaders(clusterStateRequest, existingTransportHeaders); + putHeaders(clusterStateRequest, transportHeaders); + putContext(clusterStateRequest, transportContext); + assertHeaders(clusterStateRequest, transportHeaders); client.admin().cluster().state(clusterStateRequest); assertHeaders(clusterStateRequest, expectedHeaders); + assertContext(clusterStateRequest, expectedContext); ClusterStatsRequest clusterStatsRequest = Requests.clusterStatsRequest(); - putHeaders(clusterStatsRequest, existingTransportHeaders); - assertHeaders(clusterStatsRequest, existingTransportHeaders); + putHeaders(clusterStatsRequest, transportHeaders); + putContext(clusterStatsRequest, transportContext); + assertHeaders(clusterStatsRequest, transportHeaders); client.admin().cluster().clusterStats(clusterStatsRequest); assertHeaders(clusterStatsRequest, expectedHeaders); + assertContext(clusterStatsRequest, expectedContext); } @Test public void testCopyHeadersIndicesAdminRequest() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, restContext), usefulRestHeaders); CreateIndexRequest createIndexRequest = Requests.createIndexRequest("test"); - putHeaders(createIndexRequest, existingTransportHeaders); - assertHeaders(createIndexRequest, existingTransportHeaders); + putHeaders(createIndexRequest, transportHeaders); + putContext(createIndexRequest, transportContext); + assertHeaders(createIndexRequest, transportHeaders); client.admin().indices().create(createIndexRequest); assertHeaders(createIndexRequest, expectedHeaders); + assertContext(createIndexRequest, expectedContext); CloseIndexRequest closeIndexRequest = Requests.closeIndexRequest("test"); - putHeaders(closeIndexRequest, existingTransportHeaders); - assertHeaders(closeIndexRequest, existingTransportHeaders); + putHeaders(closeIndexRequest, transportHeaders); + putContext(closeIndexRequest, transportContext); + assertHeaders(closeIndexRequest, transportHeaders); client.admin().indices().close(closeIndexRequest); assertHeaders(closeIndexRequest, expectedHeaders); + assertContext(closeIndexRequest, expectedContext); FlushRequest flushRequest = Requests.flushRequest(); - putHeaders(flushRequest, existingTransportHeaders); - assertHeaders(flushRequest, existingTransportHeaders); + putHeaders(flushRequest, transportHeaders); + putContext(flushRequest, transportContext); + assertHeaders(flushRequest, transportHeaders); client.admin().indices().flush(flushRequest); assertHeaders(flushRequest, expectedHeaders); + assertContext(flushRequest, expectedContext); } @Test public void testCopyHeadersRequestBuilder() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, restContext), usefulRestHeaders); ActionRequestBuilder requestBuilders [] = new ActionRequestBuilder[] { client.prepareIndex("index", "type"), @@ -212,26 +256,34 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { }; for (ActionRequestBuilder requestBuilder : requestBuilders) { - putHeaders(requestBuilder.request(), existingTransportHeaders); - assertHeaders(requestBuilder.request(), existingTransportHeaders); + putHeaders(requestBuilder.request(), transportHeaders); + putContext(requestBuilder.request(), transportContext); + assertHeaders(requestBuilder.request(), transportHeaders); requestBuilder.get(); assertHeaders(requestBuilder.request(), expectedHeaders); + assertContext(requestBuilder.request(), expectedContext); } } @Test public void testCopyHeadersClusterAdminRequestBuilder() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, restContext), usefulRestHeaders); ActionRequestBuilder requestBuilders [] = new ActionRequestBuilder[] { client.admin().cluster().prepareNodesInfo(), @@ -243,26 +295,34 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { }; for (ActionRequestBuilder requestBuilder : requestBuilders) { - putHeaders(requestBuilder.request(), existingTransportHeaders); - assertHeaders(requestBuilder.request(), existingTransportHeaders); + putHeaders(requestBuilder.request(), transportHeaders); + putContext(requestBuilder.request(), transportContext); + assertHeaders(requestBuilder.request(), transportHeaders); requestBuilder.get(); assertHeaders(requestBuilder.request(), expectedHeaders); + assertContext(requestBuilder.request(), expectedContext); } } @Test public void testCopyHeadersIndicesAdminRequestBuilder() { - Map existingTransportHeaders = randomHeaders(randomIntBetween(0, 10)); + Map transportHeaders = randomHeaders(randomIntBetween(0, 10)); Map restHeaders = randomHeaders(randomIntBetween(0, 10)); - Map leftRestHeaders = randomHeadersFrom(restHeaders); - Set usefulRestHeaders = new HashSet<>(leftRestHeaders.keySet()); - usefulRestHeaders.addAll(randomHeaders(randomIntBetween(0, 10), "useful-").keySet()); + Map copiedHeaders = randomHeadersFrom(restHeaders); + Set usefulRestHeaders = new HashSet<>(copiedHeaders.keySet()); + usefulRestHeaders.addAll(randomMap(randomIntBetween(0, 10), "useful-").keySet()); + Map restContext = randomContext(randomIntBetween(0, 10)); + Map transportContext = Maps.difference(randomContext(randomIntBetween(0, 10)), restContext).entriesOnlyOnLeft(); HashMap expectedHeaders = new HashMap<>(); - expectedHeaders.putAll(existingTransportHeaders); - expectedHeaders.putAll(leftRestHeaders); + expectedHeaders.putAll(transportHeaders); + expectedHeaders.putAll(copiedHeaders); - Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders), usefulRestHeaders); + Map expectedContext = new HashMap<>(); + expectedContext.putAll(transportContext); + expectedContext.putAll(restContext); + + Client client = client(new NoOpClient(), new FakeRestRequest(restHeaders, restContext), usefulRestHeaders); ActionRequestBuilder requestBuilders [] = new ActionRequestBuilder[] { client.admin().indices().prepareValidateQuery(), @@ -275,18 +335,32 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { }; for (ActionRequestBuilder requestBuilder : requestBuilders) { - putHeaders(requestBuilder.request(), existingTransportHeaders); - assertHeaders(requestBuilder.request(), existingTransportHeaders); + putHeaders(requestBuilder.request(), transportHeaders); + putContext(requestBuilder.request(), transportContext); + assertHeaders(requestBuilder.request(), transportHeaders); requestBuilder.get(); assertHeaders(requestBuilder.request(), expectedHeaders); + assertContext(requestBuilder.request(), expectedContext); } } private static Map randomHeaders(int count) { - return randomHeaders(count, "header-"); + return randomMap(count, "header-"); } - private static Map randomHeaders(int count, String prefix) { + private static Map randomContext(int count) { + return randomMap(count, "context-"); + } + + private static Map randomMap(int count, String prefix) { + Map headers = new HashMap<>(); + for (int i = 0; i < count; i++) { + headers.put(prefix + randomInt(30), randomAsciiOfLength(10)); + } + return headers; + } + + private static Map randomContext(int count, String prefix) { Map headers = new HashMap<>(); for (int i = 0; i < count; i++) { headers.put(prefix + randomInt(30), randomRealisticUnicodeOfLengthBetween(1, 20)); @@ -312,7 +386,7 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { if (usefulRestHeaders.isEmpty() && randomBoolean()) { return noOpClient; } - return new HeadersCopyClient(noOpClient, restRequest, usefulRestHeaders); + return new BaseRestHandler.HeadersAndContextCopyClient(noOpClient, restRequest, usefulRestHeaders); } private static void putHeaders(ActionRequest request, Map headers) { @@ -321,6 +395,12 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { } } + private static void putContext(ActionRequest request, Map context) { + for (Map.Entry header : context.entrySet()) { + request.putInContext(header.getKey(), header.getValue()); + } + } + private static void assertHeaders(ActionRequest request, Map headers) { if (headers.size() == 0) { assertThat(request.getHeaders() == null || request.getHeaders().size() == 0, equalTo(true)); @@ -333,6 +413,19 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { } } + private static void assertContext(ActionRequest request, Map context) { + if (context.size() == 0) { + assertThat(request.isContextEmpty(), is(true)); + } else { + ImmutableOpenMap map = request.getContext(); + assertThat(map, notNullValue()); + assertThat(map.size(), equalTo(context.size())); + for (Object key : map.keys()) { + assertThat(context.get(key), equalTo(request.getFromContext(key))); + } + } + } + private static class NoOpClient extends AbstractClient implements AdminClient { @Override diff --git a/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/src/test/java/org/elasticsearch/rest/RestRequestTests.java new file mode 100644 index 00000000000..f4fc4f6cf28 --- /dev/null +++ b/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -0,0 +1,114 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.rest; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.test.ElasticsearchTestCase; +import org.junit.Test; + +import java.util.Map; + +import static org.hamcrest.Matchers.*; + +/** + * + */ +public class RestRequestTests extends ElasticsearchTestCase { + + @Test + public void testContext() throws Exception { + int count = randomInt(10); + Request request = new Request(); + for (int i = 0; i < count; i++) { + request.putInContext("key" + i, "val" + i); + } + assertThat(request.isContextEmpty(), is(count == 0)); + assertThat(request.contextSize(), is(count)); + ImmutableOpenMap ctx = request.getContext(); + for (int i = 0; i < count; i++) { + assertThat(request.hasInContext("key" + i), is(true)); + assertThat((String) request.getFromContext("key" + i), equalTo("val" + i)); + assertThat((String) ctx.get("key" + i), equalTo("val" + i)); + } + } + + public static class Request extends RestRequest { + @Override + public Method method() { + return null; + } + + @Override + public String uri() { + return null; + } + + @Override + public String rawPath() { + return null; + } + + @Override + public boolean hasContent() { + return false; + } + + @Override + public boolean contentUnsafe() { + return false; + } + + @Override + public BytesReference content() { + return null; + } + + @Override + public String header(String name) { + return null; + } + + @Override + public Iterable> headers() { + return null; + } + + @Override + public boolean hasParam(String key) { + return false; + } + + @Override + public String param(String key) { + return null; + } + + @Override + public Map params() { + return null; + } + + @Override + public String param(String key, String defaultValue) { + return null; + } + } +}