459542 - AsyncMiddleManServlet race condition on first download content.

Fixed the race condition by submitting a zero length buffer to write
from onWritePossible() which will succeed the callback without
causing races.
This commit is contained in:
Simone Bordet 2015-02-10 13:10:11 +01:00
parent 7b00f6857f
commit 12e2f9e6c8
5 changed files with 226 additions and 125 deletions

View File

@ -35,6 +35,7 @@ import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader; import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.CountingCallback;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger; import org.eclipse.jetty.util.log.Logger;
@ -335,17 +336,10 @@ public abstract class HttpReceiver
} }
else else
{ {
Callback partial = new Callback.Adapter() int size = decodeds.size();
{ CountingCallback counter = new CountingCallback(callback, size);
@Override for (int i = 0; i < size; ++i)
public void failed(Throwable x) notifier.notifyContent(listeners, response, decodeds.get(i), counter);
{
callback.failed(x);
}
};
for (int i = 1, size = decodeds.size(); i <= size; ++i)
notifier.notifyContent(listeners, response, decodeds.get(i - 1), i < size ? partial : callback);
} }
} }

View File

@ -87,10 +87,10 @@ import org.eclipse.jetty.util.Callback;
*/ */
public class DeferredContentProvider implements AsyncContentProvider, Callback, Closeable public class DeferredContentProvider implements AsyncContentProvider, Callback, Closeable
{ {
private static final AsyncChunk CLOSE = new AsyncChunk(BufferUtil.EMPTY_BUFFER, Callback.Adapter.INSTANCE); private static final Chunk CLOSE = new Chunk(BufferUtil.EMPTY_BUFFER, Callback.Adapter.INSTANCE);
private final Object lock = this; private final Object lock = this;
private final Queue<AsyncChunk> chunks = new ArrayQueue<>(4, 64, lock); private final Queue<Chunk> chunks = new ArrayQueue<>(4, 64, lock);
private final AtomicReference<Listener> listener = new AtomicReference<>(); private final AtomicReference<Listener> listener = new AtomicReference<>();
private final DeferredContentProviderIterator iterator = new DeferredContentProviderIterator(); private final DeferredContentProviderIterator iterator = new DeferredContentProviderIterator();
private final AtomicBoolean closed = new AtomicBoolean(); private final AtomicBoolean closed = new AtomicBoolean();
@ -121,7 +121,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
synchronized (lock) synchronized (lock)
{ {
long total = 0; long total = 0;
for (AsyncChunk chunk : chunks) for (Chunk chunk : chunks)
total += chunk.buffer.remaining(); total += chunk.buffer.remaining();
length = total; length = total;
} }
@ -148,10 +148,10 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
public boolean offer(ByteBuffer buffer, Callback callback) public boolean offer(ByteBuffer buffer, Callback callback)
{ {
return offer(new AsyncChunk(buffer, callback)); return offer(new Chunk(buffer, callback));
} }
private boolean offer(AsyncChunk chunk) private boolean offer(Chunk chunk)
{ {
Throwable failure; Throwable failure;
boolean result = false; boolean result = false;
@ -243,7 +243,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
private class DeferredContentProviderIterator implements Iterator<ByteBuffer>, Callback private class DeferredContentProviderIterator implements Iterator<ByteBuffer>, Callback
{ {
private AsyncChunk current; private Chunk current;
@Override @Override
public boolean hasNext() public boolean hasNext()
@ -259,7 +259,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
{ {
synchronized (lock) synchronized (lock)
{ {
AsyncChunk chunk = current = chunks.poll(); Chunk chunk = current = chunks.poll();
if (chunk == CLOSE) if (chunk == CLOSE)
throw new NoSuchElementException(); throw new NoSuchElementException();
return chunk == null ? null : chunk.buffer; return chunk == null ? null : chunk.buffer;
@ -275,7 +275,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
@Override @Override
public void succeeded() public void succeeded()
{ {
AsyncChunk chunk; Chunk chunk;
synchronized (lock) synchronized (lock)
{ {
chunk = current; chunk = current;
@ -292,7 +292,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
@Override @Override
public void failed(Throwable x) public void failed(Throwable x)
{ {
List<AsyncChunk> chunks = new ArrayList<>(); List<Chunk> chunks = new ArrayList<>();
synchronized (lock) synchronized (lock)
{ {
failure = x; failure = x;
@ -302,20 +302,26 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
current = null; current = null;
lock.notify(); lock.notify();
} }
for (AsyncChunk chunk : chunks) for (Chunk chunk : chunks)
chunk.callback.failed(x); chunk.callback.failed(x);
} }
} }
private static class AsyncChunk public static class Chunk
{ {
private final ByteBuffer buffer; public final ByteBuffer buffer;
private final Callback callback; public final Callback callback;
private AsyncChunk(ByteBuffer buffer, Callback callback) public Chunk(ByteBuffer buffer, Callback callback)
{ {
this.buffer = Objects.requireNonNull(buffer); this.buffer = Objects.requireNonNull(buffer);
this.callback = Objects.requireNonNull(callback); this.callback = Objects.requireNonNull(callback);
} }
@Override
public String toString()
{
return String.format("%s@%x", getClass().getSimpleName(), hashCode());
}
} }
} }

View File

@ -25,7 +25,6 @@ import java.nio.ByteBuffer;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Objects;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -51,6 +50,7 @@ import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.io.RuntimeIOException; import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.CountingCallback;
import org.eclipse.jetty.util.IteratingCallback; import org.eclipse.jetty.util.IteratingCallback;
public class AsyncMiddleManServlet extends AbstractProxyServlet public class AsyncMiddleManServlet extends AbstractProxyServlet
@ -171,14 +171,6 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
protected class ProxyReader extends IteratingCallback implements ReadListener protected class ProxyReader extends IteratingCallback implements ReadListener
{ {
private final Callback failer = new Adapter()
{
@Override
public void failed(Throwable x)
{
onError(x);
}
};
private final byte[] buffer = new byte[getHttpClient().getRequestBufferSize()]; private final byte[] buffer = new byte[getHttpClient().getRequestBufferSize()];
private final List<ByteBuffer> buffers = new ArrayList<>(); private final List<ByteBuffer> buffers = new ArrayList<>();
private final HttpServletRequest clientRequest; private final HttpServletRequest clientRequest;
@ -207,7 +199,16 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
public void onAllDataRead() throws IOException public void onAllDataRead() throws IOException
{ {
if (!provider.isClosed()) if (!provider.isClosed())
process(BufferUtil.EMPTY_BUFFER, failer, true); {
process(BufferUtil.EMPTY_BUFFER, new Adapter()
{
@Override
public void failed(Throwable x)
{
onError(x);
}
}, true);
}
if (_log.isDebugEnabled()) if (_log.isDebugEnabled())
_log.debug("{} proxying content to upstream completed", getRequestId(clientRequest)); _log.debug("{} proxying content to upstream completed", getRequestId(clientRequest));
@ -264,21 +265,28 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
clientRequest.setAttribute(CLIENT_TRANSFORMER, transformer); clientRequest.setAttribute(CLIENT_TRANSFORMER, transformer);
} }
if (content.hasRemaining() || finished) if (!content.hasRemaining() && !finished)
{ {
int contentBytes = content.remaining(); callback.succeeded();
return;
}
int contentBytes = content.remaining();
transformer.transform(content, finished, buffers); transformer.transform(content, finished, buffers);
int newContentBytes = 0; int newContentBytes = 0;
int size = buffers.size(); int size = buffers.size();
if (size > 0)
{
CountingCallback counter = new CountingCallback(callback, size);
for (int i = 0; i < size; ++i) for (int i = 0; i < size; ++i)
{ {
ByteBuffer buffer = buffers.get(i); ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining(); newContentBytes += buffer.remaining();
provider.offer(buffer, i == size - 1 ? callback : failer); provider.offer(buffer, counter);
} }
buffers.clear(); buffers.clear();
}
if (finished) if (finished)
provider.close(); provider.close();
@ -295,7 +303,6 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
if (size == 0) if (size == 0)
succeeded(); succeeded();
} }
}
@Override @Override
protected void onCompleteFailure(Throwable x) protected void onCompleteFailure(Throwable x)
@ -368,18 +375,25 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
int newContentBytes = 0; int newContentBytes = 0;
int size = buffers.size(); int size = buffers.size();
if (size > 0)
{
CountingCallback counter = new CountingCallback(callback, size);
for (int i = 0; i < size; ++i) for (int i = 0; i < size; ++i)
{ {
ByteBuffer buffer = buffers.get(i); ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining(); newContentBytes += buffer.remaining();
proxyWriter.offer(buffer, i == size - 1 ? callback : Callback.Adapter.INSTANCE); proxyWriter.offer(buffer, counter);
} }
buffers.clear(); buffers.clear();
}
if (_log.isDebugEnabled()) if (_log.isDebugEnabled())
_log.debug("{} downstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes); _log.debug("{} downstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes);
if (committed) if (committed)
{ {
if (size == 0)
callback.succeeded();
else
proxyWriter.onWritePossible(); proxyWriter.onWritePossible();
} }
else else
@ -389,11 +403,15 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
// Setting the WriteListener triggers an invocation to // Setting the WriteListener triggers an invocation to
// onWritePossible(), possibly on a different thread. // onWritePossible(), possibly on a different thread.
// We cannot succeed the callback from here, otherwise
// we run into a race where the different thread calls
// onWritePossible() and succeeding the callback causes
// this method to be called again, which also may call
// onWritePossible(). We use a poison pill for this case.
if (size == 0)
proxyWriter.offer(BufferUtil.EMPTY_BUFFER, callback);
proxyResponse.getOutputStream().setWriteListener(proxyWriter); proxyResponse.getOutputStream().setWriteListener(proxyWriter);
} }
if (size == 0)
callback.succeeded();
} }
catch (Throwable x) catch (Throwable x)
{ {
@ -417,11 +435,9 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
long newContentBytes = 0; long newContentBytes = 0;
int size = buffers.size(); int size = buffers.size();
for (int i = 0; i < size; ++i) if (size > 0)
{ {
ByteBuffer buffer = buffers.get(i); Callback callback = new Callback.Adapter()
newContentBytes += buffer.remaining();
proxyWriter.offer(buffer, i == size - 1 ? new Callback.Adapter()
{ {
@Override @Override
public void failed(Throwable x) public void failed(Throwable x)
@ -429,9 +445,15 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
if (complete.compareAndSet(false, true)) if (complete.compareAndSet(false, true))
onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x); onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x);
} }
} : Callback.Adapter.INSTANCE); };
for (int i = 0; i < size; ++i)
{
ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining();
proxyWriter.offer(buffer, callback);
} }
buffers.clear(); buffers.clear();
}
if (_log.isDebugEnabled()) if (_log.isDebugEnabled())
_log.debug("{} downstream content transformation to {} bytes", getRequestId(clientRequest), newContentBytes); _log.debug("{} downstream content transformation to {} bytes", getRequestId(clientRequest), newContentBytes);
@ -462,10 +484,10 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
protected class ProxyWriter implements WriteListener protected class ProxyWriter implements WriteListener
{ {
private final Queue<AsyncChunk> chunks = new ArrayDeque<>(); private final Queue<DeferredContentProvider.Chunk> chunks = new ArrayDeque<>();
private final HttpServletRequest clientRequest; private final HttpServletRequest clientRequest;
private final Response serverResponse; private final Response serverResponse;
private AsyncChunk chunk; private DeferredContentProvider.Chunk chunk;
private boolean writePending; private boolean writePending;
protected ProxyWriter(HttpServletRequest clientRequest, Response serverResponse) protected ProxyWriter(HttpServletRequest clientRequest, Response serverResponse)
@ -478,48 +500,48 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
{ {
if (_log.isDebugEnabled()) if (_log.isDebugEnabled())
_log.debug("{} proxying content to downstream: {} bytes", getRequestId(clientRequest), content.remaining()); _log.debug("{} proxying content to downstream: {} bytes", getRequestId(clientRequest), content.remaining());
return chunks.offer(new AsyncChunk(content, callback)); return chunks.offer(new DeferredContentProvider.Chunk(content, callback));
} }
@Override @Override
public void onWritePossible() throws IOException public void onWritePossible() throws IOException
{ {
ServletOutputStream output = clientRequest.getAsyncContext().getResponse().getOutputStream(); ServletOutputStream output = clientRequest.getAsyncContext().getResponse().getOutputStream();
while (true)
{ // If we had a pending write, let's succeed it.
if (writePending) if (writePending)
{ {
// The write was pending but is now complete. if (_log.isDebugEnabled())
_log.debug("{} pending async write complete of {} on {}", getRequestId(clientRequest), chunk, output);
writePending = false; writePending = false;
if (_log.isDebugEnabled())
_log.debug("{} pending async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output);
if (succeed(chunk.callback)) if (succeed(chunk.callback))
break; return;
} }
else
int length = 0;
DeferredContentProvider.Chunk chunk = null;
while (output.isReady())
{ {
chunk = chunks.poll(); if (chunk != null)
{
if (_log.isDebugEnabled())
_log.debug("{} async write complete of {} ({} bytes) on {}", getRequestId(clientRequest), chunk, length, output);
if (succeed(chunk.callback))
return;
}
this.chunk = chunk = chunks.poll();
if (chunk == null) if (chunk == null)
break; return;
length = chunk.buffer.remaining();
if (length > 0)
writeProxyResponseContent(output, chunk.buffer); writeProxyResponseContent(output, chunk.buffer);
}
if (output.isReady())
{
if (_log.isDebugEnabled()) if (_log.isDebugEnabled())
_log.debug("{} async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output); _log.debug("{} async write pending of {} ({} bytes) on {}", getRequestId(clientRequest), chunk, length, output);
if (succeed(chunk.callback))
break;
}
else
{
writePending = true; writePending = true;
if (_log.isDebugEnabled())
_log.debug("{} async write pending of {} bytes on {}", getRequestId(clientRequest), chunk.length, output);
break;
}
}
}
} }
private boolean succeed(Callback callback) private boolean succeed(Callback callback)
@ -533,7 +555,7 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
// which may remain pending, which means that the reentrant call // which may remain pending, which means that the reentrant call
// to onWritePossible() returns all the way back to just after the // to onWritePossible() returns all the way back to just after the
// succeed of the callback. There, we cannot just loop attempting // succeed of the callback. There, we cannot just loop attempting
// write, but we need to check whether we are still write pending. // write, but we need to check whether we are write pending.
callback.succeeded(); callback.succeeded();
return writePending; return writePending;
} }
@ -541,7 +563,7 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
@Override @Override
public void onError(Throwable failure) public void onError(Throwable failure)
{ {
AsyncChunk chunk = this.chunk; DeferredContentProvider.Chunk chunk = this.chunk;
if (chunk != null) if (chunk != null)
chunk.callback.failed(failure); chunk.callback.failed(failure);
else else
@ -549,21 +571,6 @@ public class AsyncMiddleManServlet extends AbstractProxyServlet
} }
} }
// TODO: coalesce this class with the one from DeferredContentProvider ?
private static class AsyncChunk
{
private final ByteBuffer buffer;
private final Callback callback;
private final int length;
private AsyncChunk(ByteBuffer buffer, Callback callback)
{
this.buffer = Objects.requireNonNull(buffer);
this.callback = Objects.requireNonNull(callback);
this.length = buffer.remaining();
}
}
/** /**
* <p>Allows applications to transform upstream and downstream content.</p> * <p>Allows applications to transform upstream and downstream content.</p>
* <p>Typical use cases of transformations are URL rewriting of HTML anchors * <p>Typical use cases of transformations are URL rewriting of HTML anchors

View File

@ -99,7 +99,7 @@ public class ProxyServletTest
{ {
private static final String PROXIED_HEADER = "X-Proxied"; private static final String PROXIED_HEADER = "X-Proxied";
@Parameterized.Parameters @Parameterized.Parameters(name = "{0}")
public static Iterable<Object[]> data() public static Iterable<Object[]> data()
{ {
return Arrays.asList(new Object[][]{ return Arrays.asList(new Object[][]{

View File

@ -0,0 +1,94 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.util;
import java.util.concurrent.atomic.AtomicInteger;
/**
* <p>A callback wrapper that succeeds the wrapped callback when the count is
* reached, or on first failure.</p>
* <p>This callback is particularly useful when an async operation is split
* into multiple parts, for example when an original byte buffer that needs
* to be written, along with a callback, is split into multiple byte buffers,
* since it allows the original callback to be wrapped and notified only when
* the last part has been processed.</p>
* <p>Example:</p>
* <pre>
* public void process(EndPoint endPoint, ByteBuffer buffer, Callback callback)
* {
* ByteBuffer[] buffers = split(buffer);
* CountCallback countCallback = new CountCallback(callback, buffers.length);
* endPoint.write(countCallback, buffers);
* }
* </pre>
*/
public class CountingCallback implements Callback
{
private final Callback callback;
private final AtomicInteger count;
public CountingCallback(Callback callback, int count)
{
this.callback = callback;
this.count = new AtomicInteger(count);
}
@Override
public void succeeded()
{
// Forward success on the last success.
while (true)
{
int current = count.get();
// Already completed ?
if (current == 0)
return;
if (count.compareAndSet(current, current - 1))
{
if (current == 1)
{
callback.succeeded();
return;
}
}
}
}
@Override
public void failed(Throwable failure)
{
// Forward failure on the first failure.
while (true)
{
int current = count.get();
// Already completed ?
if (current == 0)
return;
if (count.compareAndSet(current, 0))
{
callback.failed(failure);
return;
}
}
}
}