diff --git a/core/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/core/src/main/java/org/elasticsearch/transport/TransportActionProxy.java new file mode 100644 index 00000000000..38ed2bbad73 --- /dev/null +++ b/core/src/main/java/org/elasticsearch/transport/TransportActionProxy.java @@ -0,0 +1,123 @@ +package org.elasticsearch.transport; + +import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.io.UncheckedIOException; + +// nocommit javadocs +public final class TransportActionProxy { + + private TransportActionProxy() {} // no instance + + private static class ProxyRequestHandler implements TransportRequestHandler { + + private final TransportService service; + private final String action; + private final Supplier responseFactory; + + public ProxyRequestHandler(TransportService service, String action, Supplier responseFactory) { + this.service = service; + this.action = action; + this.responseFactory = responseFactory; + } + + @Override + public void messageReceived(T request, TransportChannel channel) throws Exception { + DiscoveryNode targetNode = request.targetNode; + TransportRequest wrappedRequest = request.wrapped; + service.sendRequest(targetNode, action, wrappedRequest, new ProxyResponseHandler<>(channel, responseFactory)); + } + } + + private static class ProxyResponseHandler implements TransportResponseHandler { + + private final Supplier responseFactory; + private final TransportChannel channel; + + public ProxyResponseHandler(TransportChannel channel, Supplier responseFactory) { + this.responseFactory = responseFactory; + this.channel = channel; + + } + @Override + public T newInstance() { + return responseFactory.get(); + } + + @Override + public void handleResponse(T response) { + try { + channel.sendResponse(response); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public void handleException(TransportException exp) { + try { + channel.sendResponse(exp); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + } + + static class ProxyRequest extends TransportRequest { + T wrapped; + Supplier supplier; + DiscoveryNode targetNode; + + public ProxyRequest(Supplier supplier) { + this.supplier = supplier; + } + + public ProxyRequest(T wrapped, DiscoveryNode targetNode) { + this.wrapped = wrapped; + this.targetNode = targetNode; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + targetNode = new DiscoveryNode(in); + wrapped = supplier.get(); + wrapped.readFrom(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + targetNode.writeTo(out); + wrapped.writeTo(out); + } + } + + /** + * Registers a proxy request handler that allows to forward requests for the given action to another node. + */ + public static String registerProxyAction(TransportService service, String action, Supplier responseSupplier) { + RequestHandlerRegistry requestHandler = service.getRequestHandler(action); + String proxyAction = "internal:transport/proxy/" + action; + service.registerRequestHandler(proxyAction, () -> new ProxyRequest(requestHandler::newRequest), ThreadPool.Names.SAME, true, false + , new ProxyRequestHandler<>(service, action, responseSupplier)); + return proxyAction; + } + + //nocommit javadocs + public static void sendProxyRequest(TransportService service, DiscoveryNode proxyNode, DiscoveryNode targetNode, String action, + TransportRequest request, TransportResponseHandler handler) { + String proxyAction = "internal:transport/proxy/" + action; + service.sendRequest(proxyNode, proxyAction, new ProxyRequest(request, targetNode), handler); + } +} diff --git a/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java new file mode 100644 index 00000000000..be4e7fd3ae4 --- /dev/null +++ b/core/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -0,0 +1,252 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.transport; + +import org.apache.lucene.util.IOUtils; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.Before; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; + +public class TransportActionProxyTests extends ESTestCase { + protected ThreadPool threadPool; + // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions + private static final Version CURRENT_VERSION = Version.fromString(String.valueOf(Version.CURRENT.major) + ".0.0"); + protected static final Version version0 = CURRENT_VERSION.minimumCompatibilityVersion(); + + protected DiscoveryNode nodeA; + protected MockTransportService serviceA; + + protected static final Version version1 = Version.fromId(CURRENT_VERSION.id + 1); + protected DiscoveryNode nodeB; + protected MockTransportService serviceB; + + protected DiscoveryNode nodeC; + protected MockTransportService serviceC; + + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = new TestThreadPool(getClass().getName()); + serviceA = buildService(version0); // this one supports dynamic tracer updates + nodeA = new DiscoveryNode("TS_A", serviceA.boundAddress().publishAddress(), emptyMap(), emptySet(), version0); + serviceB = buildService(version1); // this one doesn't support dynamic tracer updates + nodeB = new DiscoveryNode("TS_B", serviceB.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + serviceC = buildService(version1); // this one doesn't support dynamic tracer updates + nodeC = new DiscoveryNode("TS_C", serviceC.boundAddress().publishAddress(), emptyMap(), emptySet(), version1); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + IOUtils.close(serviceA, serviceB, serviceC, () -> { + try { + terminate(threadPool); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + } + + private MockTransportService buildService(final Version version) { + MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, version, threadPool, null); + service.start(); + service.acceptIncomingRequests(); + return service; + + } + + + public void testSendMessage() throws InterruptedException { + serviceA.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + assertEquals(request.sourceNode, "TS_A"); + SimpleTestResponse response = new SimpleTestResponse(); + response.targetNode = "TS_A"; + channel.sendResponse(response); + }); + TransportActionProxy.registerProxyAction(serviceA, "/test", SimpleTestResponse::new); + serviceA.connectToNode(nodeB); + + serviceB.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + assertEquals(request.sourceNode, "TS_A"); + SimpleTestResponse response = new SimpleTestResponse(); + response.targetNode = "TS_B"; + channel.sendResponse(response); + }); + TransportActionProxy.registerProxyAction(serviceB, "/test", SimpleTestResponse::new); + serviceB.connectToNode(nodeC); + serviceC.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + assertEquals(request.sourceNode, "TS_A"); + SimpleTestResponse response = new SimpleTestResponse(); + response.targetNode = "TS_C"; + channel.sendResponse(response); + }); + TransportActionProxy.registerProxyAction(serviceC, "/test", SimpleTestResponse::new); + + CountDownLatch latch = new CountDownLatch(1); + TransportActionProxy.sendProxyRequest(serviceA, nodeB, nodeC, "/test", new SimpleTestRequest("TS_A"), + new TransportResponseHandler() { + @Override + public SimpleTestResponse newInstance() { + return new SimpleTestResponse(); + } + + @Override + public void handleResponse(SimpleTestResponse response) { + try { + assertEquals("TS_C", response.targetNode); + } finally { + latch.countDown(); + } + } + + @Override + public void handleException(TransportException exp) { + try { + throw new AssertionError(exp); + } finally { + latch.countDown(); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); + latch.await(); + } + + public void testException() throws InterruptedException { + serviceA.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + assertEquals(request.sourceNode, "TS_A"); + SimpleTestResponse response = new SimpleTestResponse(); + response.targetNode = "TS_A"; + channel.sendResponse(response); + }); + TransportActionProxy.registerProxyAction(serviceA, "/test", SimpleTestResponse::new); + serviceA.connectToNode(nodeB); + + serviceB.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + assertEquals(request.sourceNode, "TS_A"); + SimpleTestResponse response = new SimpleTestResponse(); + response.targetNode = "TS_B"; + channel.sendResponse(response); + }); + TransportActionProxy.registerProxyAction(serviceB, "/test", SimpleTestResponse::new); + serviceB.connectToNode(nodeC); + serviceC.registerRequestHandler("/test", SimpleTestRequest::new, ThreadPool.Names.SAME, + (request, channel) -> { + throw new ElasticsearchException("greetings from TS_C"); + }); + TransportActionProxy.registerProxyAction(serviceC, "/test", SimpleTestResponse::new); + + CountDownLatch latch = new CountDownLatch(1); + TransportActionProxy.sendProxyRequest(serviceA, nodeB, nodeC, "/test", new SimpleTestRequest("TS_A"), + new TransportResponseHandler() { + @Override + public SimpleTestResponse newInstance() { + return new SimpleTestResponse(); + } + + @Override + public void handleResponse(SimpleTestResponse response) { + try { + fail("expected exception"); + } finally { + latch.countDown(); + } + } + + @Override + public void handleException(TransportException exp) { + try { + Throwable cause = ExceptionsHelper.unwrapCause(exp); + assertEquals("greetings from TS_C", cause.getMessage()); + } finally { + latch.countDown(); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); + latch.await(); + } + + public static class SimpleTestRequest extends TransportRequest { + String sourceNode; + + public SimpleTestRequest(String sourceNode) { + this.sourceNode = sourceNode; + } + public SimpleTestRequest() {} + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + sourceNode = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(sourceNode); + } + } + + public static class SimpleTestResponse extends TransportResponse { + String targetNode; + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + targetNode = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(targetNode); + } + } + +}