Cleanup of the transport request/response messages

Now both TransportRequest and TransportResponse inherit from a base TransportMessage that holds the message headers and also now added the remote transport address (where this message came from).
This commit is contained in:
uboness 2014-07-11 14:40:51 +02:00
parent c4c0270c52
commit 25a21c6a01
11 changed files with 201 additions and 117 deletions

View File

@ -39,6 +39,11 @@ public class DummyTransportAddress implements TransportAddress {
return 0; return 0;
} }
@Override
public boolean sameHost(TransportAddress other) {
return other == INSTANCE;
}
@Override @Override
public void readFrom(StreamInput in) throws IOException { public void readFrom(StreamInput in) throws IOException {
} }

View File

@ -67,6 +67,12 @@ public class InetSocketTransportAddress implements TransportAddress {
return 1; return 1;
} }
@Override
public boolean sameHost(TransportAddress other) {
return other instanceof InetSocketTransportAddress &&
address.getAddress().equals(((InetSocketTransportAddress) other).address.getAddress());
}
public InetSocketAddress address() { public InetSocketAddress address() {
return this.address; return this.address;
} }

View File

@ -47,6 +47,11 @@ public class LocalTransportAddress implements TransportAddress {
return 2; return 2;
} }
@Override
public boolean sameHost(TransportAddress other) {
return other instanceof LocalTransportAddress && id.equals(((LocalTransportAddress) other).id);
}
@Override @Override
public void readFrom(StreamInput in) throws IOException { public void readFrom(StreamInput in) throws IOException {
id = in.readString(); id = in.readString();

View File

@ -29,4 +29,6 @@ import java.io.Serializable;
public interface TransportAddress extends Streamable, Serializable { public interface TransportAddress extends Streamable, Serializable {
short uniqueAddressTypeId(); short uniqueAddressTypeId();
boolean sameHost(TransportAddress other);
} }

View File

@ -0,0 +1,99 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.transport;
import com.google.common.collect.Maps;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import org.elasticsearch.common.transport.TransportAddress;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
*
*/
public abstract class TransportMessage<TM extends TransportMessage<TM>> implements Streamable {
private Map<String, Object> headers;
private TransportAddress remoteAddress;
protected TransportMessage() {
}
protected TransportMessage(TM message) {
// create a new copy of the headers, since we are creating a new request which might have
// its headers changed in the context of that specific request
if (message.getHeaders() != null) {
this.headers = new HashMap<>(message.getHeaders());
}
}
public void remoteAddress(TransportAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
public TransportAddress remoteAddress() {
return remoteAddress;
}
@SuppressWarnings("unchecked")
public final TM putHeader(String key, Object value) {
if (headers == null) {
headers = Maps.newHashMap();
}
headers.put(key, value);
return (TM) this;
}
@SuppressWarnings("unchecked")
public final <V> V getHeader(String key) {
if (headers == null) {
return null;
}
return (V) headers.get(key);
}
public Map<String, Object> getHeaders() {
return this.headers;
}
@Override
public void readFrom(StreamInput in) throws IOException {
if (in.readBoolean()) {
headers = in.readMap();
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
if (headers == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(headers);
}
}
}

View File

@ -19,18 +19,9 @@
package org.elasticsearch.transport; package org.elasticsearch.transport;
import com.google.common.collect.Maps;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/** /**
*/ */
public abstract class TransportRequest implements Streamable { public abstract class TransportRequest extends TransportMessage<TransportRequest> {
public static class Empty extends TransportRequest { public static class Empty extends TransportRequest {
@ -45,55 +36,11 @@ public abstract class TransportRequest implements Streamable {
} }
} }
private Map<String, Object> headers;
protected TransportRequest() { protected TransportRequest() {
} }
protected TransportRequest(TransportRequest request) { protected TransportRequest(TransportRequest request) {
// create a new copy of the headers, since we are creating a new request which might have super(request);
// its headers changed in the context of that specific request
if (request.getHeaders() != null) {
this.headers = new HashMap<>(request.getHeaders());
}
} }
@SuppressWarnings("unchecked")
public final TransportRequest putHeader(String key, Object value) {
if (headers == null) {
headers = Maps.newHashMap();
}
headers.put(key, value);
return this;
}
@SuppressWarnings("unchecked")
public final <V> V getHeader(String key) {
if (headers == null) {
return null;
}
return (V) headers.get(key);
}
public Map<String, Object> getHeaders() {
return this.headers;
}
@Override
public void readFrom(StreamInput in) throws IOException {
if (in.readBoolean()) {
headers = in.readMap();
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
if (headers == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(headers);
}
}
} }

View File

@ -19,18 +19,9 @@
package org.elasticsearch.transport; package org.elasticsearch.transport;
import com.google.common.collect.Maps;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/** /**
*/ */
public abstract class TransportResponse implements Streamable { public abstract class TransportResponse extends TransportMessage<TransportResponse> {
public static class Empty extends TransportResponse { public static class Empty extends TransportResponse {
@ -45,55 +36,11 @@ public abstract class TransportResponse implements Streamable {
} }
} }
private Map<String, Object> headers;
protected TransportResponse() { protected TransportResponse() {
} }
protected TransportResponse(TransportResponse request) { protected TransportResponse(TransportResponse response) {
// create a new copy of the headers, since we are creating a new request which might have super(response);
// its headers changed in the context of that specific request
if (request.getHeaders() != null) {
this.headers = new HashMap<>(request.getHeaders());
}
} }
@SuppressWarnings("unchecked")
public final TransportResponse putHeader(String key, Object value) {
if (headers == null) {
headers = Maps.newHashMap();
}
headers.put(key, value);
return this;
}
@SuppressWarnings("unchecked")
public final <V> V getHeader(String key) {
if (headers == null) {
return null;
}
return (V) headers.get(key);
}
public Map<String, Object> getHeaders() {
return this.headers;
}
@Override
public void readFrom(StreamInput in) throws IOException {
if (in.readBoolean()) {
headers = in.readMap();
}
}
@Override
public void writeTo(StreamOutput out) throws IOException {
if (headers == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeMap(headers);
}
}
} }

View File

@ -217,7 +217,7 @@ public class LocalTransport extends AbstractLifecycleComponent<Transport> implem
if (TransportStatus.isError(status)) { if (TransportStatus.isError(status)) {
handlerResponseError(stream, handler); handlerResponseError(stream, handler);
} else { } else {
handleResponse(stream, handler); handleResponse(stream, sourceTransport, handler);
} }
} }
} }
@ -242,6 +242,7 @@ public class LocalTransport extends AbstractLifecycleComponent<Transport> implem
throw new ActionNotFoundTransportException("Action [" + action + "] not found"); throw new ActionNotFoundTransportException("Action [" + action + "] not found");
} }
final TransportRequest request = handler.newInstance(); final TransportRequest request = handler.newInstance();
request.remoteAddress(sourceTransport.boundAddress.publishAddress());
request.readFrom(stream); request.readFrom(stream);
if (handler.executor() == ThreadPool.Names.SAME) { if (handler.executor() == ThreadPool.Names.SAME) {
//noinspection unchecked //noinspection unchecked
@ -282,9 +283,9 @@ public class LocalTransport extends AbstractLifecycleComponent<Transport> implem
} }
} }
protected void handleResponse(StreamInput buffer, LocalTransport sourceTransport, final TransportResponseHandler handler) {
protected void handleResponse(StreamInput buffer, final TransportResponseHandler handler) {
final TransportResponse response = handler.newInstance(); final TransportResponse response = handler.newInstance();
response.remoteAddress(sourceTransport.boundAddress.publishAddress());
try { try {
response.readFrom(buffer); response.readFrom(buffer);
} catch (Throwable e) { } catch (Throwable e) {

View File

@ -28,6 +28,7 @@ import org.elasticsearch.common.io.ThrowableObjectInputStream;
import org.elasticsearch.common.io.stream.CachedStreamInput; import org.elasticsearch.common.io.stream.CachedStreamInput;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.logging.ESLogger;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.*; import org.elasticsearch.transport.*;
@ -36,6 +37,7 @@ import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.*; import org.jboss.netty.channel.*;
import java.io.IOException; import java.io.IOException;
import java.net.InetSocketAddress;
/** /**
* A handler (must be the last one!) that does size based frame decoding and forwards the actual message * A handler (must be the last one!) that does size based frame decoding and forwards the actual message
@ -122,7 +124,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
if (TransportStatus.isError(status)) { if (TransportStatus.isError(status)) {
handlerResponseError(wrappedStream, handler); handlerResponseError(wrappedStream, handler);
} else { } else {
handleResponse(wrappedStream, handler); handleResponse(ctx.getChannel(), wrappedStream, handler);
} }
} else { } else {
// if its null, skip those bytes // if its null, skip those bytes
@ -140,8 +142,10 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
wrappedStream.close(); wrappedStream.close();
} }
private void handleResponse(StreamInput buffer, final TransportResponseHandler handler) { private void handleResponse(Channel channel, StreamInput buffer, final TransportResponseHandler handler) {
final TransportResponse response = handler.newInstance(); final TransportResponse response = handler.newInstance();
response.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
response.remoteAddress();
try { try {
response.readFrom(buffer); response.readFrom(buffer);
} catch (Throwable e) { } catch (Throwable e) {
@ -206,6 +210,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
throw new ActionNotFoundTransportException(action); throw new ActionNotFoundTransportException(action);
} }
final TransportRequest request = handler.newInstance(); final TransportRequest request = handler.newInstance();
request.remoteAddress(new InetSocketTransportAddress((InetSocketAddress) channel.getRemoteAddress()));
request.readFrom(buffer); request.readFrom(buffer);
if (handler.executor() == ThreadPool.Names.SAME) { if (handler.executor() == ThreadPool.Names.SAME) {
//noinspection unchecked //noinspection unchecked

View File

@ -26,6 +26,7 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.test.ElasticsearchTestCase; import org.elasticsearch.test.ElasticsearchTestCase;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
@ -37,6 +38,7 @@ import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import static org.elasticsearch.transport.TransportRequestOptions.options; import static org.elasticsearch.transport.TransportRequestOptions.options;
import static org.hamcrest.Matchers.*; import static org.hamcrest.Matchers.*;
@ -1048,4 +1050,70 @@ public abstract class AbstractSimpleTransportTests extends ElasticsearchTestCase
serviceA.removeHandler("sayHello"); serviceA.removeHandler("sayHello");
} }
@Test
public void testHostOnMessages() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(2);
final AtomicReference<TransportAddress> addressA = new AtomicReference<>();
final AtomicReference<TransportAddress> addressB = new AtomicReference<>();
serviceB.registerHandler("action1", new TransportRequestHandler<TestRequest>() {
@Override
public TestRequest newInstance() {
return new TestRequest();
}
@Override
public void messageReceived(TestRequest request, TransportChannel channel) throws Exception {
latch.countDown();
addressA.set(request.remoteAddress());
channel.sendResponse(new TestResponse());
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
@Override
public boolean isForceExecution() {
return false;
}
});
serviceA.sendRequest(nodeB, "action1", new TestRequest(), new TransportResponseHandler<TestResponse>() {
@Override
public TestResponse newInstance() {
return new TestResponse();
}
@Override
public void handleResponse(TestResponse response) {
latch.countDown();
addressB.set(response.remoteAddress());
}
@Override
public void handleException(TransportException exp) {
latch.countDown();
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
});
if (!latch.await(10, TimeUnit.SECONDS)) {
fail("message round trip did not complete within a sensible time frame");
}
assertTrue(nodeA.address().sameHost(addressA.get()));
assertTrue(nodeB.address().sameHost(addressB.get()));
}
private static class TestRequest extends TransportRequest {
}
private static class TestResponse extends TransportResponse {
}
} }

View File

@ -24,7 +24,6 @@ import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.transport.AbstractSimpleTransportTests; import org.elasticsearch.transport.AbstractSimpleTransportTests;
import org.elasticsearch.transport.TransportService;
public class SimpleLocalTransportTests extends AbstractSimpleTransportTests { public class SimpleLocalTransportTests extends AbstractSimpleTransportTests {