From 278e5d3a430289b82a6c8a4ed560f3ec96691939 Mon Sep 17 00:00:00 2001
From: Shay Banon
Date: Thu, 9 Feb 2012 00:15:08 +0200
Subject: [PATCH] Transport buffer overrun can happen because of byte buffer reading optimization introduced in 0.19.0.RC1, closes #1686.

---
 .../http/netty/NettyHttpRequest.java          |  5 +-
 .../netty/MessageChannelHandler.java          | 56 +++++++----
 .../ConcurrentSearchSerializationTests.java   | 93 +++++++++++++++++++
 3 files changed, 131 insertions(+), 23 deletions(-)
 create mode 100644 src/test/java/org/elasticsearch/test/stress/search1/ConcurrentSearchSerializationTests.java

diff --git a/src/main/java/org/elasticsearch/http/netty/NettyHttpRequest.java b/src/main/java/org/elasticsearch/http/netty/NettyHttpRequest.java
index e2defd0faa0..2ebea802872 100644
--- a/src/main/java/org/elasticsearch/http/netty/NettyHttpRequest.java
+++ b/src/main/java/org/elasticsearch/http/netty/NettyHttpRequest.java
@@ -108,10 +108,7 @@ public class NettyHttpRequest extends AbstractRestRequest implements HttpRequest
 
     @Override
     public boolean contentUnsafe() {
-        // the netty HTTP handling always copy over the buffer to its own buffer, either in NioWorker internally
-        // when reading, or using a cumalation buffer
-
-        // also, HttpMessageDecoder#content variable gets freshly created for each request and not reused across
+        // HttpMessageDecoder#content variable gets freshly created for each request and not reused across
        // requests
         return false;
         //return request.getContent().hasArray();
diff --git a/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java b/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java
index 18827611620..210f60cec40 100644
--- a/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java
+++ b/src/main/java/org/elasticsearch/transport/netty/MessageChannelHandler.java
@@ -34,7 +34,6 @@ import org.jboss.netty.channel.*;
 
 import java.io.IOException;
 import java.io.StreamCorruptedException;
-import java.net.SocketAddress;
 
 /**
  * A handler (must be the last one!) that does size based frame decoding and forwards the actual message
@@ -67,10 +66,12 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
     }
 
     // similar logic to FrameDecoder, we don't use FrameDecoder because we can use the data len header value
-    // to guess the size of the cumulation buffer to allocate
+    // to guess the size of the cumulation buffer to allocate, and because we make a fresh copy of the cumulation
+    // buffer so we can readBytesReference from it without another request writing into the same one, in case
+    // one message and a partial next message exist within the same input
 
-    // we don't reuse the cumalation buffer, so it won't grow out of control per channel, as well as
-    // being able to "readBytesReference" from it without worry
+    // we can readBytesReference because NioWorker always copies the input buffer into a fresh buffer, and we
+    // don't reuse the cumulation buffer
 
     @Override
     public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
@@ -89,9 +90,9 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
         if (cumulation != null && cumulation.readable()) {
             cumulation.discardReadBytes();
             cumulation.writeBytes(input);
-            callDecode(ctx, e.getChannel(), cumulation, e.getRemoteAddress());
+            callDecode(ctx, e.getChannel(), cumulation, true);
         } else {
-            int actualSize = callDecode(ctx, e.getChannel(), input, e.getRemoteAddress());
+            int actualSize = callDecode(ctx, e.getChannel(), input, false);
             if (input.readable()) {
                 if (actualSize > 0) {
                     cumulation = ChannelBuffers.dynamicBuffer(actualSize, ctx.getChannel().getConfig().getBufferFactory());
@@ -116,33 +117,50 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
         cleanup(ctx, e);
     }
 
-    private int callDecode(ChannelHandlerContext context, Channel channel, ChannelBuffer cumulation, SocketAddress remoteAddress) throws Exception {
-        while (cumulation.readable()) {
+    private int callDecode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, boolean cumulationBuffer) throws Exception {
+        int actualSize = 0;
+        while (buffer.readable()) {
+            actualSize = 0;
             // Changes from Frame Decoder, to combine SizeHeader and this decoder into one...
-            if (cumulation.readableBytes() < 4) {
+            if (buffer.readableBytes() < 4) {
                 break; // we need more data
             }
 
-            int dataLen = cumulation.getInt(cumulation.readerIndex());
+            int dataLen = buffer.getInt(buffer.readerIndex());
             if (dataLen <= 0) {
                 throw new StreamCorruptedException("invalid data length: " + dataLen);
             }
 
-            int actualSize = dataLen + 4;
-            if (cumulation.readableBytes() < actualSize) {
-                return actualSize;
+            actualSize = dataLen + 4;
+            if (buffer.readableBytes() < actualSize) {
+                break;
             }
 
-            cumulation.skipBytes(4);
+            buffer.skipBytes(4);
 
-            process(context, channel, cumulation, dataLen);
+            process(ctx, channel, buffer, dataLen);
         }
 
-        if (!cumulation.readable()) {
-            this.cumulation = null;
+        if (cumulationBuffer) {
+            assert buffer == this.cumulation;
+            if (!buffer.readable()) {
+                this.cumulation = null;
+            } else if (buffer.readerIndex() > 0) {
+                // make a fresh copy of the cumulation buffer, so we
+                // can readBytesReference from it, and also so we don't keep it around
+
+                // it's not that big of an overhead, since discardReadBytes in the next messageReceived round will
+                // copy the bytes over to the start again
+                if (actualSize > 0) {
+                    this.cumulation = ChannelBuffers.dynamicBuffer(actualSize, ctx.getChannel().getConfig().getBufferFactory());
+                } else {
+                    this.cumulation = ChannelBuffers.dynamicBuffer(ctx.getChannel().getConfig().getBufferFactory());
+                }
+                this.cumulation.writeBytes(buffer);
+            }
         }
 
-        return 0;
+        return actualSize;
     }
 
@@ -157,7 +175,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
 
         if (cumulation.readable()) {
             // Make sure all frames are read before notifying a closed channel.
-            callDecode(ctx, ctx.getChannel(), cumulation, null);
+            callDecode(ctx, ctx.getChannel(), cumulation, true);
         }
 
         // Call decodeLast() finally. Please note that decodeLast() is
diff --git a/src/test/java/org/elasticsearch/test/stress/search1/ConcurrentSearchSerializationTests.java b/src/test/java/org/elasticsearch/test/stress/search1/ConcurrentSearchSerializationTests.java
new file mode 100644
index 00000000000..e4e4aee23ac
--- /dev/null
+++ b/src/test/java/org/elasticsearch/test/stress/search1/ConcurrentSearchSerializationTests.java
@@ -0,0 +1,93 @@
+package org.elasticsearch.test.stress.search1;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.common.RandomStringGenerator;
+import org.elasticsearch.common.settings.ImmutableSettings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.node.Node;
+import org.elasticsearch.node.NodeBuilder;
+import org.elasticsearch.search.SearchHit;
+
+import java.util.concurrent.CountDownLatch;
+
+/**
+ * Tests that data doesn't get corrupted while reading it over the streams.
+ * <p/>
+ * See: https://github.com/elasticsearch/elasticsearch/issues/1686.
+ */
+public class ConcurrentSearchSerializationTests {
+
+    public static void main(String[] args) throws Exception {
+
+        Settings settings = ImmutableSettings.settingsBuilder().put("gateway.type", "none").build();
+
+        Node node1 = NodeBuilder.nodeBuilder().settings(settings).node();
+        Node node2 = NodeBuilder.nodeBuilder().settings(settings).node();
+        Node node3 = NodeBuilder.nodeBuilder().settings(settings).node();
+
+        final Client client = node1.client();
+
+        System.out.println("Indexing...");
+        final String data = RandomStringGenerator.random(100);
+        final CountDownLatch latch1 = new CountDownLatch(100);
+        for (int i = 0; i < 100; i++) {
+            client.prepareIndex("test", "type", Integer.toString(i))
+                    .setSource("field", data)
+                    .execute(new ActionListener<IndexResponse>() {
+                        @Override
+                        public void onResponse(IndexResponse indexResponse) {
+                            latch1.countDown();
+                        }
+
+                        @Override
+                        public void onFailure(Throwable e) {
+                            latch1.countDown();
+                        }
+                    });
+        }
+        latch1.await();
+        System.out.println("Indexed");
+
+        System.out.println("searching...");
+        Thread[] threads = new Thread[10];
+        final CountDownLatch latch = new CountDownLatch(threads.length);
+        for (int i = 0; i < threads.length; i++) {
+            threads[i] = new Thread(new Runnable() {
+                @Override
+                public void run() {
+                    for (int i = 0; i < 1000; i++) {
+                        SearchResponse searchResponse = client.prepareSearch("test")
+                                .setQuery(QueryBuilders.matchAllQuery())
+                                .setSize(i % 100)
+                                .execute().actionGet();
+                        for (SearchHit hit : searchResponse.hits()) {
+                            try {
+                                if (!hit.sourceAsMap().get("field").equals(data)) {
+                                    System.err.println("Field not equal!");
+                                }
+                            } catch (Exception e) {
+                                e.printStackTrace();
+                            }
+                        }
+                    }
+                    latch.countDown();
+                }
+            });
+        }
+        for (Thread thread : threads) {
+            thread.start();
+        }
+
+        latch.await();
+
+        System.out.println("done searching");
+        client.close();
+        node1.close();
+        node2.close();
+        node3.close();
+    }
+}
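
The sketch below is a minimal, hypothetical rendition of the framing scheme that callDecode implements, written against plain java.nio.ByteBuffer instead of Netty's ChannelBuffer; the class FrameDecoderSketch and its feed/frame methods are invented for illustration, not part of the patch. It walks the same steps as the patched handler: peek at the 4-byte length header, stop and buffer when only a partial frame has arrived, reject non-positive lengths with a StreamCorruptedException, and, crucially for this fix, copy both complete payloads and the leftover partial frame into fresh buffers so that bytes already handed to a caller can never be overwritten by the next read.

import java.io.StreamCorruptedException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FrameDecoderSketch {

    // stands in for the handler's cumulation buffer; always a fresh copy
    private byte[] cumulation = new byte[0];

    // feeds one chunk of network input, returns all complete payloads found
    List<byte[]> feed(byte[] input) throws StreamCorruptedException {
        byte[] data = new byte[cumulation.length + input.length];
        System.arraycopy(cumulation, 0, data, 0, cumulation.length);
        System.arraycopy(input, 0, data, cumulation.length, input.length);
        ByteBuffer buffer = ByteBuffer.wrap(data);

        List<byte[]> messages = new ArrayList<byte[]>();
        while (buffer.remaining() >= 4) {
            int dataLen = buffer.getInt(buffer.position()); // peek at the size header
            if (dataLen <= 0) {
                throw new StreamCorruptedException("invalid data length: " + dataLen);
            }
            if (buffer.remaining() < dataLen + 4) {
                break; // we need more data
            }
            buffer.position(buffer.position() + 4); // consume the size header
            byte[] payload = new byte[dataLen];     // fresh copy handed to the caller
            buffer.get(payload);
            messages.add(payload);
        }
        // copy the remainder instead of keeping a view into `data`, so payloads
        // already handed out can never be overwritten by the next round
        cumulation = new byte[buffer.remaining()];
        buffer.get(cumulation);
        return messages;
    }

    // frames a message: 4-byte big-endian length header followed by the payload
    static byte[] frame(String message) {
        byte[] payload = message.getBytes(StandardCharsets.UTF_8);
        return ByteBuffer.allocate(4 + payload.length).putInt(payload.length).put(payload).array();
    }

    public static void main(String[] args) throws Exception {
        byte[] first = frame("first");
        byte[] second = frame("second");
        byte[] wire = new byte[first.length + second.length];
        System.arraycopy(first, 0, wire, 0, first.length);
        System.arraycopy(second, 0, wire, first.length, second.length);

        FrameDecoderSketch decoder = new FrameDecoderSketch();
        // deliver one complete message plus a partial next one, then the rest
        int split = first.length + 3;
        for (byte[] msg : decoder.feed(Arrays.copyOfRange(wire, 0, split))) {
            System.out.println(new String(msg, StandardCharsets.UTF_8)); // prints "first"
        }
        for (byte[] msg : decoder.feed(Arrays.copyOfRange(wire, split, wire.length))) {
            System.out.println(new String(msg, StandardCharsets.UTF_8)); // prints "second"
        }
    }
}

Skipping those copies reproduces the class of bug this patch closes: a reader would hold a view into the same buffer that discardReadBytes and the next writeBytes round reuse, so later input could overwrite bytes still being deserialized, which is the corruption that ConcurrentSearchSerializationTests above tries to provoke with many concurrent searches.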