diff --git a/src/main/java/org/elasticsearch/common/transport/DummyTransportAddress.java b/src/main/java/org/elasticsearch/common/transport/DummyTransportAddress.java index 0f0a21de8c4..7d788ed298f 100644 --- a/src/main/java/org/elasticsearch/common/transport/DummyTransportAddress.java +++ b/src/main/java/org/elasticsearch/common/transport/DummyTransportAddress.java @@ -39,6 +39,11 @@ public class DummyTransportAddress implements TransportAddress { return 0; } + @Override + public boolean sameHost(TransportAddress other) { + return other == INSTANCE; + } + @Override public void readFrom(StreamInput in) throws IOException { } diff --git a/src/main/java/org/elasticsearch/common/transport/InetSocketTransportAddress.java b/src/main/java/org/elasticsearch/common/transport/InetSocketTransportAddress.java index 202a93af929..1bc519435de 100644 --- a/src/main/java/org/elasticsearch/common/transport/InetSocketTransportAddress.java +++ b/src/main/java/org/elasticsearch/common/transport/InetSocketTransportAddress.java @@ -67,6 +67,12 @@ public class InetSocketTransportAddress implements TransportAddress { return 1; } + @Override + public boolean sameHost(TransportAddress other) { + return other instanceof InetSocketTransportAddress && + address.getAddress().equals(((InetSocketTransportAddress) other).address.getAddress()); + } + public InetSocketAddress address() { return this.address; } diff --git a/src/main/java/org/elasticsearch/common/transport/LocalTransportAddress.java b/src/main/java/org/elasticsearch/common/transport/LocalTransportAddress.java index 0f7fbcd1035..6a8bc082b6e 100644 --- a/src/main/java/org/elasticsearch/common/transport/LocalTransportAddress.java +++ b/src/main/java/org/elasticsearch/common/transport/LocalTransportAddress.java @@ -47,6 +47,11 @@ public class LocalTransportAddress implements TransportAddress { return 2; } + @Override + public boolean sameHost(TransportAddress other) { + return other instanceof LocalTransportAddress && id.equals(((LocalTransportAddress) other).id); + } + @Override public void readFrom(StreamInput in) throws IOException { id = in.readString(); diff --git a/src/main/java/org/elasticsearch/common/transport/TransportAddress.java b/src/main/java/org/elasticsearch/common/transport/TransportAddress.java index d70a201c9cc..e9f7ac01774 100644 --- a/src/main/java/org/elasticsearch/common/transport/TransportAddress.java +++ b/src/main/java/org/elasticsearch/common/transport/TransportAddress.java @@ -29,4 +29,6 @@ import java.io.Serializable; public interface TransportAddress extends Streamable, Serializable { short uniqueAddressTypeId(); + + boolean sameHost(TransportAddress other); } diff --git a/src/main/java/org/elasticsearch/transport/TransportMessage.java b/src/main/java/org/elasticsearch/transport/TransportMessage.java new file mode 100644 index 00000000000..b684c76fa36 --- /dev/null +++ b/src/main/java/org/elasticsearch/transport/TransportMessage.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import com.google.common.collect.Maps; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Streamable; +import org.elasticsearch.common.transport.TransportAddress; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * + */ +public abstract class TransportMessage> implements Streamable { + + private Map headers; + + private TransportAddress remoteAddress; + + protected TransportMessage() { + } + + protected TransportMessage(TM message) { + // create a new copy of the headers, since we are creating a new request which might have + // its headers changed in the context of that specific request + if (message.getHeaders() != null) { + this.headers = new HashMap<>(message.getHeaders()); + } + } + + + public void remoteAddress(TransportAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + + public TransportAddress remoteAddress() { + return remoteAddress; + } + + @SuppressWarnings("unchecked") + public final TM putHeader(String key, Object value) { + if (headers == null) { + headers = Maps.newHashMap(); + } + headers.put(key, value); + return (TM) this; + } + + @SuppressWarnings("unchecked") + public final V getHeader(String key) { + if (headers == null) { + return null; + } + return (V) headers.get(key); + } + + public Map getHeaders() { + return this.headers; + } + + + @Override + public void readFrom(StreamInput in) throws IOException { + if (in.readBoolean()) { + headers = in.readMap(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (headers == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + out.writeMap(headers); + } + } +} diff --git a/src/main/java/org/elasticsearch/transport/TransportRequest.java b/src/main/java/org/elasticsearch/transport/TransportRequest.java index 14b25777e52..bd2e83db4fa 100644 --- a/src/main/java/org/elasticsearch/transport/TransportRequest.java +++ b/src/main/java/org/elasticsearch/transport/TransportRequest.java @@ -19,18 +19,9 @@ package org.elasticsearch.transport; -import com.google.common.collect.Maps; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Streamable; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - /** */ -public abstract class TransportRequest implements Streamable { +public abstract class TransportRequest extends TransportMessage { public static class Empty extends TransportRequest { @@ -45,55 +36,11 @@ public abstract class TransportRequest implements Streamable { } } - private Map headers; - protected TransportRequest() { - } protected TransportRequest(TransportRequest request) { - // create a new copy of the headers, since we are creating a new request which might have - // its headers changed in the context of that specific request - if (request.getHeaders() != null) { - this.headers = new HashMap<>(request.getHeaders()); - } + super(request); } - @SuppressWarnings("unchecked") - public final TransportRequest putHeader(String key, Object value) { - if (headers == null) { - headers = Maps.newHashMap(); - } - headers.put(key, value); - return this; - } - - @SuppressWarnings("unchecked") - public final V getHeader(String key) { - if (headers == null) { - return null; - } - return (V) headers.get(key); - } - - public Map getHeaders() { - return this.headers; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - if (in.readBoolean()) { - headers = in.readMap(); - } - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - if (headers == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - out.writeMap(headers); - } - } } diff --git a/src/main/java/org/elasticsearch/transport/TransportResponse.java b/src/main/java/org/elasticsearch/transport/TransportResponse.java index 669eb80b860..8ea7cd60d2d 100644 --- a/src/main/java/org/elasticsearch/transport/TransportResponse.java +++ b/src/main/java/org/elasticsearch/transport/TransportResponse.java @@ -19,18 +19,9 @@ package org.elasticsearch.transport; -import com.google.common.collect.Maps; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Streamable; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - /** */ -public abstract class TransportResponse implements Streamable { +public abstract class TransportResponse extends TransportMessage { public static class Empty extends TransportResponse { @@ -45,55 +36,11 @@ public abstract class TransportResponse implements Streamable { } } - private Map headers; - protected TransportResponse() { - } - protected TransportResponse(TransportResponse request) { - // create a new copy of the headers, since we are creating a new request which might have - // its headers changed in the context of that specific request - if (request.getHeaders() != null) { - this.headers = new HashMap<>(request.getHeaders()); - } + protected TransportResponse(TransportResponse response) { + super(response); } - @SuppressWarnings("unchecked") - public final TransportResponse putHeader(String key, Object value) { - if (headers == null) { - headers = Maps.newHashMap(); - } - headers.put(key, value); - return this; - } - - @SuppressWarnings("unchecked") - public final V getHeader(String key) { - if (headers == null) { - return null; - } - return (V) headers.get(key); - } - - public Map getHeaders() { - return this.headers; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - if (in.readBoolean()) { - headers = in.readMap(); - } - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - if (headers == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - out.writeMap(headers); - } - } } diff --git a/src/main/java/org/elasticsearch/transport/local/LocalTransport.java b/src/main/java/org/elasticsearch/transport/local/LocalTransport.java index b9ee0525e29..75c801ee4ec 100644 --- a/src/main/java/org/elasticsearch/transport/local/LocalTransport.java +++ b/src/main/java/org/elasticsearch/transport/local/LocalTransport.java @@ -217,7 +217,7 @@ public class LocalTransport extends AbstractLifecycleComponent implem if (TransportStatus.isError(status)) { handlerResponseError(stream, handler); } else { - handleResponse(stream, handler); + handleResponse(stream, sourceTransport, handler); } } } @@ -242,6 +242,7 @@ public class LocalTransport extends AbstractLifecycleComponent implem throw new ActionNotFoundTransportException("Action [" + action + "] not found"); } final TransportRequest request = handler.newInstance(); + request.remoteAddress(sourceTransport.boundAddress.publishAddress()); request.readFrom(stream); if (handler.executor() == ThreadPool.Names.SAME) { //noinspection unchecked @@ -282,9 +283,9 @@ public class LocalTransport extends AbstractLifecycleComponent implem } } - - protected void handleResponse(StreamInput buffer, final TransportResponseHandler handler) { + protected void handleResponse(StreamInput buffer, LocalTransport sourceTransport, final TransportResponseHandler handler) { final TransportResponse response = handler.newInstance(); + response.remoteAddress(sourceTransport.boundAddress.publishAddress()); try { response.readFrom(buffer); } catch (Throwable e) { diff --git a/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java b/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java index 70c4d750b69..74ad0b37ac8 100644 --- a/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java +++ b/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java @@ -28,6 +28,7 @@ import org.elasticsearch.common.io.ThrowableObjectInputStream; import org.elasticsearch.common.io.stream.CachedStreamInput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.logging.ESLogger; +import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.*; @@ -36,6 +37,7 @@ import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.channel.*; import java.io.IOException; +import java.net.InetSocketAddress; /** * A handler (must be the last one!) that does size based frame decoding and forwards the actual message @@ -122,7 +124,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler { if (TransportStatus.isError(status)) { handlerResponseError(wrappedStream, handler); } else { - handleResponse(wrappedStream, handler); + handleResponse(ctx.getChannel(), wrappedStream, handler); } } else { // if its null, skip those bytes @@ -140,8 +142,10 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler { wrappedStream.close(); } - private void handleResponse(StreamInput buffer, final TransportResponseHandler handler) { + private void handleResponse(Channel channel, StreamInput buffer, final TransportResponseHandler handler) { final TransportResponse response = handler.newInstance(); + response.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress())); + response.remoteAddress(); try { response.readFrom(buffer); } catch (Throwable e) { @@ -206,6 +210,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler { throw new ActionNotFoundTransportException(action); } final TransportRequest request = handler.newInstance(); + request.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress())); request.readFrom(buffer); if (handler.executor() == ThreadPool.Names.SAME) { //noinspection unchecked diff --git a/src/test/java/org/elasticsearch/transport/AbstractSimpleTransportTests.java b/src/test/java/org/elasticsearch/transport/AbstractSimpleTransportTests.java index d590811a95c..dd184421f65 100644 --- a/src/test/java/org/elasticsearch/transport/AbstractSimpleTransportTests.java +++ b/src/test/java/org/elasticsearch/transport/AbstractSimpleTransportTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.test.ElasticsearchTestCase; import org.elasticsearch.test.transport.MockTransportService; @@ -37,6 +38,7 @@ import org.junit.Test; import java.io.IOException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.transport.TransportRequestOptions.options; import static org.hamcrest.Matchers.*; @@ -1048,4 +1050,70 @@ public abstract class AbstractSimpleTransportTests extends ElasticsearchTestCase serviceA.removeHandler("sayHello"); } + + + @Test + public void testHostOnMessages() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(2); + final AtomicReference addressA = new AtomicReference<>(); + final AtomicReference addressB = new AtomicReference<>(); + serviceB.registerHandler("action1", new TransportRequestHandler() { + @Override + public TestRequest newInstance() { + return new TestRequest(); + } + + @Override + public void messageReceived(TestRequest request, TransportChannel channel) throws Exception { + latch.countDown(); + addressA.set(request.remoteAddress()); + channel.sendResponse(new TestResponse()); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public boolean isForceExecution() { + return false; + } + }); + serviceA.sendRequest(nodeB, "action1", new TestRequest(), new TransportResponseHandler() { + @Override + public TestResponse newInstance() { + return new TestResponse(); + } + + @Override + public void handleResponse(TestResponse response) { + latch.countDown(); + addressB.set(response.remoteAddress()); + } + + @Override + public void handleException(TransportException exp) { + latch.countDown(); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); + + if (!latch.await(10, TimeUnit.SECONDS)) { + fail("message round trip did not complete within a sensible time frame"); + } + + assertTrue(nodeA.address().sameHost(addressA.get())); + assertTrue(nodeB.address().sameHost(addressB.get())); + } + + private static class TestRequest extends TransportRequest { + } + + private static class TestResponse extends TransportResponse { + } } diff --git a/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java b/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java index 14721a94b3c..2dc29f17ec7 100644 --- a/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java +++ b/src/test/java/org/elasticsearch/transport/local/SimpleLocalTransportTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.transport.AbstractSimpleTransportTests; -import org.elasticsearch.transport.TransportService; public class SimpleLocalTransportTests extends AbstractSimpleTransportTests {