diff --git a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java index da1dcf43e5d..6098ee9f7da 100644 --- a/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TCPTransportTests.java @@ -19,8 +19,25 @@ package org.elasticsearch.transport; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.compress.CompressorFactory; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.support.TransportStatus; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; /** Unit tests for TCPTransport */ public class TCPTransportTests extends ESTestCase { @@ -127,4 +144,102 @@ public class TCPTransportTests extends ESTestCase { assertEquals(101, addresses[1].getPort()); assertEquals(102, addresses[2].getPort()); } + + public void testCompressRequest() throws IOException { + final boolean compressed = randomBoolean(); + final AtomicBoolean called = new AtomicBoolean(false); + Req request = new Req(randomRealisticUnicodeOfLengthBetween(10, 100)); + ThreadPool threadPool = new TestThreadPool(TCPTransportTests.class.getName()); + try { + TcpTransport transport = new TcpTransport("test", Settings.builder().put("transport.tcp.compress", compressed).build(), + threadPool, new BigArrays(Settings.EMPTY, null), null, null, null) { + @Override + protected InetSocketAddress getLocalAddress(Object o) { + return null; + } + + @Override + protected Object bind(String name, InetSocketAddress address) throws IOException { + return null; + } + + @Override + protected void closeChannels(List channel) throws IOException { + + } + + @Override + protected NodeChannels connectToChannelsLight(DiscoveryNode node) throws IOException { + return new NodeChannels(new Object[0], new Object[0], new Object[0], new Object[0], new Object[0]); + } + + @Override + protected void sendMessage(Object o, BytesReference reference, Runnable sendListener) throws IOException { + StreamInput streamIn = reference.streamInput(); + streamIn.skip(TcpHeader.MARKER_BYTES_SIZE); + int len = streamIn.readInt(); + long requestId = streamIn.readLong(); + assertEquals(42, requestId); + byte status = streamIn.readByte(); + Version version = Version.fromId(streamIn.readInt()); + assertEquals(Version.CURRENT, version); + assertEquals(compressed, TransportStatus.isCompress(status)); + called.compareAndSet(false, true); + if (compressed) { + final int bytesConsumed = TcpHeader.HEADER_SIZE; + streamIn = CompressorFactory.compressor(reference.slice(bytesConsumed, reference.length() - bytesConsumed)).streamInput(streamIn); + } + threadPool.getThreadContext().readHeaders(streamIn); + assertEquals("foobar", streamIn.readString()); + Req readReq = new Req(""); + readReq.readFrom(streamIn); + assertEquals(request.value, readReq.value); + } + + @Override + protected NodeChannels connectToChannels(DiscoveryNode node) throws IOException { + return new NodeChannels(new Object[0], new Object[0], new Object[0], new Object[0], new Object[0]); + } + + @Override + protected boolean isOpen(Object o) { + return false; + } + + @Override + public long serverOpen() { + return 0; + } + + @Override + protected Object nodeChannel(DiscoveryNode node, TransportRequestOptions options) throws ConnectTransportException { + return new NodeChannels(new Object[0], new Object[0], new Object[0], new Object[0], new Object[0]); + } + }; + DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT); + transport.sendRequest(node, 42, "foobar", request, TransportRequestOptions.EMPTY); + assertTrue(called.get()); + } finally { + ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS); + } + } + + private static final class Req extends TransportRequest { + public String value; + + private Req(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } + }