diff --git a/src/main/java/org/elasticsearch/transport/TransportMessage.java b/src/main/java/org/elasticsearch/transport/TransportMessage.java index b684c76fa36..fd7d4ef707d 100644 --- a/src/main/java/org/elasticsearch/transport/TransportMessage.java +++ b/src/main/java/org/elasticsearch/transport/TransportMessage.java @@ -19,36 +19,58 @@ package org.elasticsearch.transport; -import com.google.common.collect.Maps; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; import org.elasticsearch.common.transport.TransportAddress; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; /** * */ public abstract class TransportMessage> implements Streamable { + // a transient (not serialized with the request) key/value registry + private final ConcurrentMap context; + private Map headers; private TransportAddress remoteAddress; protected TransportMessage() { + context = new ConcurrentHashMap<>(); } protected TransportMessage(TM message) { - // create a new copy of the headers, since we are creating a new request which might have - // its headers changed in the context of that specific request - if (message.getHeaders() != null) { - this.headers = new HashMap<>(message.getHeaders()); + // create a new copy of the headers/context, since we are creating a new request + // which might have its headers/context changed in the context of that specific request + + if (((TransportMessage) message).headers != null) { + this.headers = new HashMap<>(((TransportMessage) message).headers); } + this.context = new ConcurrentHashMap<>(((TransportMessage) message).context); } + /** + * The request context enables attaching transient data with the request - data + * that is not serialized along with the request. + * + * There are many use cases such data is required, for example, when processing the + * request headers and building other constructs from them, one could "cache" the + * already built construct to avoid reprocessing the header over and over again. + * + * @return The request context + */ + public ConcurrentMap context() { + return context; + } public void remoteAddress(TransportAddress remoteAddress) { this.remoteAddress = remoteAddress; @@ -61,7 +83,7 @@ public abstract class TransportMessage> implemen @SuppressWarnings("unchecked") public final TM putHeader(String key, Object value) { if (headers == null) { - headers = Maps.newHashMap(); + headers = new HashMap<>(); } headers.put(key, value); return (TM) this; @@ -69,22 +91,20 @@ public abstract class TransportMessage> implemen @SuppressWarnings("unchecked") public final V getHeader(String key) { - if (headers == null) { - return null; - } - return (V) headers.get(key); + return headers != null ? (V) headers.get(key) : null; } - public Map getHeaders() { - return this.headers; + public final boolean hasHeader(String key) { + return headers != null && headers.containsKey(key); } + public Set getHeaders() { + return headers != null ? headers.keySet() : Collections.emptySet(); + } @Override public void readFrom(StreamInput in) throws IOException { - if (in.readBoolean()) { - headers = in.readMap(); - } + headers = in.readBoolean() ? in.readMap() : null; } @Override @@ -96,4 +116,5 @@ public abstract class TransportMessage> implemen out.writeMap(headers); } } + } diff --git a/src/test/java/org/elasticsearch/client/AbstractClientHeadersTests.java b/src/test/java/org/elasticsearch/client/AbstractClientHeadersTests.java index 222cef8315b..0b5455caa23 100644 --- a/src/test/java/org/elasticsearch/client/AbstractClientHeadersTests.java +++ b/src/test/java/org/elasticsearch/client/AbstractClientHeadersTests.java @@ -53,10 +53,12 @@ import org.elasticsearch.client.support.Headers; import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ElasticsearchTestCase; +import org.elasticsearch.transport.TransportMessage; import org.junit.After; import org.junit.Before; import org.junit.Test; +import java.util.HashMap; import java.util.Map; import static org.hamcrest.Matchers.*; @@ -135,9 +137,12 @@ public abstract class AbstractClientHeadersTests extends ElasticsearchTestCase { private final String action; private final Map headers; - public InternalException(String action, Map headers) { + public InternalException(String action, TransportMessage message) { this.action = action; - this.headers = headers; + this.headers = new HashMap<>(); + for (String key : message.getHeaders()) { + headers.put(key, message.getHeader(key)); + } } } diff --git a/src/test/java/org/elasticsearch/client/node/NodeClientHeadersTests.java b/src/test/java/org/elasticsearch/client/node/NodeClientHeadersTests.java index 503d0d997a2..6f4b9ddffad 100644 --- a/src/test/java/org/elasticsearch/client/node/NodeClientHeadersTests.java +++ b/src/test/java/org/elasticsearch/client/node/NodeClientHeadersTests.java @@ -84,7 +84,7 @@ public class NodeClientHeadersTests extends AbstractClientHeadersTests { @Override protected void doExecute(ActionRequest request, ActionListener listener) { - listener.onFailure(new InternalException(actionName, request.getHeaders())); + listener.onFailure(new InternalException(actionName, request)); } } diff --git a/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java b/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java index fb989134e19..fe43475e4b9 100644 --- a/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java +++ b/src/test/java/org/elasticsearch/client/transport/TransportClientHeadersTests.java @@ -68,7 +68,7 @@ public class TransportClientHeadersTests extends AbstractClientHeadersTests { ((TransportResponseHandler) handler).handleResponse(new NodesInfoResponse(ClusterName.DEFAULT, new NodeInfo[0])); return; } - handler.handleException(new TransportException("", new InternalException(action, request.getHeaders()))); + handler.handleException(new TransportException("", new InternalException(action, request))); } @Override diff --git a/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java b/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java index 5536a1c03a7..e987c0a5f35 100644 --- a/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java +++ b/src/test/java/org/elasticsearch/rest/HeadersCopyClientTests.java @@ -325,8 +325,8 @@ public class HeadersCopyClientTests extends ElasticsearchTestCase { } else { assertThat(request.getHeaders(), notNullValue()); assertThat(request.getHeaders().size(), equalTo(headers.size())); - for (Map.Entry entry : request.getHeaders().entrySet()) { - assertThat(headers.get(entry.getKey()), equalTo(entry.getValue())); + for (String key : request.getHeaders()) { + assertThat(headers.get(key), equalTo(request.getHeader(key))); } } } diff --git a/src/test/java/org/elasticsearch/test/ElasticsearchTestCase.java b/src/test/java/org/elasticsearch/test/ElasticsearchTestCase.java index e1df0fa2b64..14c0e5a1b47 100644 --- a/src/test/java/org/elasticsearch/test/ElasticsearchTestCase.java +++ b/src/test/java/org/elasticsearch/test/ElasticsearchTestCase.java @@ -19,10 +19,14 @@ package org.elasticsearch.test; import com.carrotsearch.randomizedtesting.RandomizedTest; -import com.carrotsearch.randomizedtesting.annotations.*; +import com.carrotsearch.randomizedtesting.annotations.Listeners; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope.Scope; +import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import org.apache.lucene.search.FieldCache; import org.apache.lucene.store.MockDirectoryWrapper; import org.apache.lucene.util.AbstractRandomizedTest; diff --git a/src/test/java/org/elasticsearch/test/transport/ConfigurableErrorNettyTransportModule.java b/src/test/java/org/elasticsearch/test/transport/ConfigurableErrorNettyTransportModule.java index 32f2a4f5c6d..da2a84fede8 100644 --- a/src/test/java/org/elasticsearch/test/transport/ConfigurableErrorNettyTransportModule.java +++ b/src/test/java/org/elasticsearch/test/transport/ConfigurableErrorNettyTransportModule.java @@ -95,8 +95,8 @@ public class ConfigurableErrorNettyTransportModule extends AbstractModule { final TransportRequest request = handler.newInstance(); request.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress())); request.readFrom(buffer); - if (request.getHeaders() != null && request.getHeaders().containsKey("ERROR")) { - throw new ElasticsearchException((String) request.getHeaders().get("ERROR")); + if (request.hasHeader("ERROR")) { + throw new ElasticsearchException((String) request.getHeader("ERROR")); } if (handler.executor() == ThreadPool.Names.SAME) { //noinspection unchecked diff --git a/src/test/java/org/elasticsearch/transport/TransportMessageTests.java b/src/test/java/org/elasticsearch/transport/TransportMessageTests.java new file mode 100644 index 00000000000..6144008320c --- /dev/null +++ b/src/test/java/org/elasticsearch/transport/TransportMessageTests.java @@ -0,0 +1,80 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.BytesStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.test.ElasticsearchTestCase; +import org.junit.Test; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +/** + * + */ +public class TransportMessageTests extends ElasticsearchTestCase { + + @Test + public void testTransientContext() throws Exception { + Message message = new Message(); + message.putHeader("key1", "value1"); + message.putHeader("key2", "value2"); + message.context().put("key3", "value3"); + + BytesStreamOutput out = new BytesStreamOutput(); + out.setVersion(Version.CURRENT); + message.writeTo(out); + BytesStreamInput in = new BytesStreamInput(out.bytes()); + in.setVersion(Version.CURRENT); + message = new Message(); + message.readFrom(in); + assertThat(message.getHeaders().size(), is(2)); + assertThat((String) message.getHeader("key1"), equalTo("value1")); + assertThat((String) message.getHeader("key2"), equalTo("value2")); + assertThat(message.context().isEmpty(), is(true)); + } + + @Test + public void testCopyHeadersAndContext() throws Exception { + Message m1 = new Message(); + m1.putHeader("key1", "value1"); + m1.putHeader("key2", "value2"); + m1.context().put("key3", "value3"); + + Message m2 = new Message(m1); + + assertThat(m2.getHeaders().size(), is(2)); + assertThat((String) m2.getHeader("key1"), equalTo("value1")); + assertThat((String) m2.getHeader("key2"), equalTo("value2")); + assertThat((String) m2.context().get("key3"), equalTo("value3")); + } + + private static class Message extends TransportMessage { + + private Message() { + } + + private Message(Message message) { + super(message); + } + } +}