Transport buffer overrun can happen because of byte buffer reading optimization introduced in 0.19.0.RC1, closes #1686.

This commit is contained in:
Shay Banon 2012-02-09 00:15:08 +02:00
parent a135c9bd8b
commit 278e5d3a43
3 changed files with 131 additions and 23 deletions

View File

@ -108,10 +108,7 @@ public class NettyHttpRequest extends AbstractRestRequest implements HttpRequest
@Override @Override
public boolean contentUnsafe() { public boolean contentUnsafe() {
// the netty HTTP handling always copy over the buffer to its own buffer, either in NioWorker internally // HttpMessageDecoder#content variable gets freshly created for each request and not reused across
// when reading, or using a cumalation buffer
// also, HttpMessageDecoder#content variable gets freshly created for each request and not reused across
// requests // requests
return false; return false;
//return request.getContent().hasArray(); //return request.getContent().hasArray();

View File

@ -34,7 +34,6 @@ import org.jboss.netty.channel.*;
import java.io.IOException; import java.io.IOException;
import java.io.StreamCorruptedException; import java.io.StreamCorruptedException;
import java.net.SocketAddress;
/** /**
* A handler (must be the last one!) that does size based frame decoding and forwards the actual message * A handler (must be the last one!) that does size based frame decoding and forwards the actual message
@ -67,10 +66,12 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
} }
// similar logic to FrameDecoder, we don't use FrameDecoder because we can use the data len header value // similar logic to FrameDecoder, we don't use FrameDecoder because we can use the data len header value
// to guess the size of the cumulation buffer to allocate // to guess the size of the cumulation buffer to allocate, and because we make a fresh copy of the cumulation
// buffer so we can readBytesReference from it without other request writing into the same one in case
// two one message and a partial next message exists within the same input
// we don't reuse the cumalation buffer, so it won't grow out of control per channel, as well as // we can readBytesReference because NioWorker always copies the input buffer into a fresh buffer, and we
// being able to "readBytesReference" from it without worry // don't reuse cumumlation buffer
@Override @Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
@ -89,9 +90,9 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
if (cumulation != null && cumulation.readable()) { if (cumulation != null && cumulation.readable()) {
cumulation.discardReadBytes(); cumulation.discardReadBytes();
cumulation.writeBytes(input); cumulation.writeBytes(input);
callDecode(ctx, e.getChannel(), cumulation, e.getRemoteAddress()); callDecode(ctx, e.getChannel(), cumulation, true);
} else { } else {
int actualSize = callDecode(ctx, e.getChannel(), input, e.getRemoteAddress()); int actualSize = callDecode(ctx, e.getChannel(), input, false);
if (input.readable()) { if (input.readable()) {
if (actualSize > 0) { if (actualSize > 0) {
cumulation = ChannelBuffers.dynamicBuffer(actualSize, ctx.getChannel().getConfig().getBufferFactory()); cumulation = ChannelBuffers.dynamicBuffer(actualSize, ctx.getChannel().getConfig().getBufferFactory());
@ -116,33 +117,50 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
cleanup(ctx, e); cleanup(ctx, e);
} }
private int callDecode(ChannelHandlerContext context, Channel channel, ChannelBuffer cumulation, SocketAddress remoteAddress) throws Exception { private int callDecode(ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, boolean cumulationBuffer) throws Exception {
while (cumulation.readable()) { int actualSize = 0;
while (buffer.readable()) {
actualSize = 0;
// Changes from Frame Decoder, to combine SizeHeader and this decoder into one... // Changes from Frame Decoder, to combine SizeHeader and this decoder into one...
if (cumulation.readableBytes() < 4) { if (buffer.readableBytes() < 4) {
break; // we need more data break; // we need more data
} }
int dataLen = cumulation.getInt(cumulation.readerIndex()); int dataLen = buffer.getInt(buffer.readerIndex());
if (dataLen <= 0) { if (dataLen <= 0) {
throw new StreamCorruptedException("invalid data length: " + dataLen); throw new StreamCorruptedException("invalid data length: " + dataLen);
} }
int actualSize = dataLen + 4; actualSize = dataLen + 4;
if (cumulation.readableBytes() < actualSize) { if (buffer.readableBytes() < actualSize) {
return actualSize; break;
} }
cumulation.skipBytes(4); buffer.skipBytes(4);
process(context, channel, cumulation, dataLen); process(ctx, channel, buffer, dataLen);
} }
if (!cumulation.readable()) { if (cumulationBuffer) {
this.cumulation = null; assert buffer == this.cumulation;
if (!buffer.readable()) {
this.cumulation = null;
} else if (buffer.readerIndex() > 0) {
// make a fresh copy of the cumalation buffer, so we
// can readBytesReference from it, and also, don't keep it around
// its not that big of an overhead since discardReadBytes in the next round messageReceived will
// copy over the bytes to the start again
if (actualSize > 0) {
this.cumulation = ChannelBuffers.dynamicBuffer(actualSize, ctx.getChannel().getConfig().getBufferFactory());
} else {
this.cumulation = ChannelBuffers.dynamicBuffer(ctx.getChannel().getConfig().getBufferFactory());
}
this.cumulation.writeBytes(buffer);
}
} }
return 0; return actualSize;
} }
@ -157,7 +175,7 @@ public class MessageChannelHandler extends SimpleChannelUpstreamHandler {
if (cumulation.readable()) { if (cumulation.readable()) {
// Make sure all frames are read before notifying a closed channel. // Make sure all frames are read before notifying a closed channel.
callDecode(ctx, ctx.getChannel(), cumulation, null); callDecode(ctx, ctx.getChannel(), cumulation, true);
} }
// Call decodeLast() finally. Please note that decodeLast() is // Call decodeLast() finally. Please note that decodeLast() is

View File

@ -0,0 +1,93 @@
package org.elasticsearch.test.stress.search1;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.RandomStringGenerator;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.node.Node;
import org.elasticsearch.node.NodeBuilder;
import org.elasticsearch.search.SearchHit;
import java.util.concurrent.CountDownLatch;
/**
* Tests that data don't get corrupted while reading it over the streams.
* <p/>
* See: https://github.com/elasticsearch/elasticsearch/issues/1686.
*/
public class ConcurrentSearchSerializationTests {
public static void main(String[] args) throws Exception {
Settings settings = ImmutableSettings.settingsBuilder().put("gateway.type", "none").build();
Node node1 = NodeBuilder.nodeBuilder().settings(settings).node();
Node node2 = NodeBuilder.nodeBuilder().settings(settings).node();
Node node3 = NodeBuilder.nodeBuilder().settings(settings).node();
final Client client = node1.client();
System.out.println("Indexing...");
final String data = RandomStringGenerator.random(100);
final CountDownLatch latch1 = new CountDownLatch(100);
for (int i = 0; i < 100; i++) {
client.prepareIndex("test", "type", Integer.toString(i))
.setSource("field", data)
.execute(new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
latch1.countDown();
}
@Override
public void onFailure(Throwable e) {
latch1.countDown();
}
});
}
latch1.await();
System.out.println("Indexed");
System.out.println("searching...");
Thread[] threads = new Thread[10];
final CountDownLatch latch = new CountDownLatch(threads.length);
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(new Runnable() {
@Override
public void run() {
for (int i = 0; i < 1000; i++) {
SearchResponse searchResponse = client.prepareSearch("test")
.setQuery(QueryBuilders.matchAllQuery())
.setSize(i % 100)
.execute().actionGet();
for (SearchHit hit : searchResponse.hits()) {
try {
if (!hit.sourceAsMap().get("field").equals(data)) {
System.err.println("Field not equal!");
}
} catch (Exception e) {
e.printStackTrace();
}
}
}
latch.countDown();
}
});
}
for (Thread thread : threads) {
thread.start();
}
latch.await();
System.out.println("done searching");
client.close();
node1.close();
node2.close();
node3.close();
}
}