Merged branch 'jetty-9.2.x' into 'master'.

This commit is contained in:
Simone Bordet 2015-01-29 14:11:21 +01:00
commit 961a90d16c
15 changed files with 2387 additions and 578 deletions

View File

@ -20,7 +20,6 @@ package org.eclipse.jetty.alpn.server;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import org.eclipse.jetty.alpn.ALPN; import org.eclipse.jetty.alpn.ALPN;
@ -54,31 +53,38 @@ public class ALPNServerConnection extends NegotiatingServerConnection implements
List<String> serverProtocols = getProtocols(); List<String> serverProtocols = getProtocols();
String tlsProtocol = sslEngine.getHandshakeSession().getProtocol(); String tlsProtocol = sslEngine.getHandshakeSession().getProtocol();
String tlsCipher = sslEngine.getHandshakeSession().getCipherSuite(); String tlsCipher = sslEngine.getHandshakeSession().getCipherSuite();
String negotiated = null; String negotiated = null;
for (String clientProtocol : clientProtocols)
// RFC 7301 states that the server picks the protocol
// that it prefers that is also supported by the client.
for (String serverProtocol : serverProtocols)
{ {
if (serverProtocols.contains(clientProtocol)) if (clientProtocols.contains(serverProtocol))
{ {
ConnectionFactory factory = getConnector().getConnectionFactory(clientProtocol); ConnectionFactory factory = getConnector().getConnectionFactory(serverProtocol);
if (factory instanceof CipherDiscriminator && !((CipherDiscriminator)factory).isAcceptable(serverProtocol, tlsProtocol, tlsCipher))
if (factory instanceof CipherDiscriminator && !((CipherDiscriminator)factory).isAcceptable(clientProtocol,tlsProtocol,tlsCipher))
{ {
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("{} protocol {} not acceptable to {} for {}/{}", this, clientProtocol,factory,tlsProtocol,tlsCipher); LOG.debug("{} protocol {} not acceptable to {} for {}/{}", this, serverProtocol, factory, tlsProtocol, tlsCipher);
continue; continue;
} }
negotiated = clientProtocol; negotiated = serverProtocol;
break; break;
} }
} }
if (negotiated == null) if (negotiated == null)
{ {
if (clientProtocols.isEmpty()) if (clientProtocols.isEmpty())
negotiated=getDefaultProtocol(); {
negotiated = getDefaultProtocol();
}
else else
throw new IllegalStateException("No acceptable protocol"); {
if (LOG.isDebugEnabled())
LOG.debug("{} could not negotiate protocol: C[{}] | S[{}]", this, clientProtocols, serverProtocols);
throw new IllegalStateException();
}
} }
if (LOG.isDebugEnabled()) if (LOG.isDebugEnabled())
LOG.debug("{} protocol selected {}", this, negotiated); LOG.debug("{} protocol selected {}", this, negotiated);

View File

@ -251,7 +251,7 @@ public abstract class HttpDestination implements Destination, Closeable, Dumpabl
@Override @Override
public void dump(Appendable out, String indent) throws IOException public void dump(Appendable out, String indent) throws IOException
{ {
ContainerLifeCycle.dumpObject(out, this + " - requests queued: " + exchanges.size()); ContainerLifeCycle.dumpObject(out, toString());
} }
public String asString() public String asString()
@ -262,9 +262,10 @@ public abstract class HttpDestination implements Destination, Closeable, Dumpabl
@Override @Override
public String toString() public String toString()
{ {
return String.format("%s[%s]%s,queue=%d", return String.format("%s[%s]%x%s,queue=%d",
HttpDestination.class.getSimpleName(), HttpDestination.class.getSimpleName(),
asString(), asString(),
hashCode(),
proxy == null ? "" : "(via " + proxy + ")", proxy == null ? "" : "(via " + proxy + ")",
exchanges.size()); exchanges.size());
} }

View File

@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response; import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.api.Result; import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Log;
@ -60,7 +59,7 @@ public class HttpExchange
return request.getConversation(); return request.getConversation();
} }
public Request getRequest() public HttpRequest getRequest()
{ {
return request; return request;
} }

View File

@ -160,35 +160,27 @@ public abstract class HttpSender implements AsyncContentProvider.Listener
public void send(HttpExchange exchange) public void send(HttpExchange exchange)
{ {
Request request = exchange.getRequest(); Request request = exchange.getRequest();
Throwable cause = request.getAbortCause(); if (!queuedToBegin(request))
if (cause != null) return;
{
exchange.abort(cause);
}
else
{
if (!queuedToBegin(request))
throw new IllegalStateException();
ContentProvider contentProvider = request.getContent(); ContentProvider contentProvider = request.getContent();
HttpContent content = this.content = new HttpContent(contentProvider); HttpContent content = this.content = new HttpContent(contentProvider);
SenderState newSenderState = SenderState.SENDING; SenderState newSenderState = SenderState.SENDING;
if (expects100Continue(request)) if (expects100Continue(request))
newSenderState = content.hasContent() ? SenderState.EXPECTING_WITH_CONTENT : SenderState.EXPECTING; newSenderState = content.hasContent() ? SenderState.EXPECTING_WITH_CONTENT : SenderState.EXPECTING;
if (!updateSenderState(SenderState.IDLE, newSenderState)) if (!updateSenderState(SenderState.IDLE, newSenderState))
throw illegalSenderState(SenderState.IDLE); throw illegalSenderState(SenderState.IDLE);
// Setting the listener may trigger calls to onContent() by other // Setting the listener may trigger calls to onContent() by other
// threads so we must set it only after the sender state has been updated // threads so we must set it only after the sender state has been updated
if (contentProvider instanceof AsyncContentProvider) if (contentProvider instanceof AsyncContentProvider)
((AsyncContentProvider)contentProvider).setListener(this); ((AsyncContentProvider)contentProvider).setListener(this);
if (!beginToHeaders(request)) if (!beginToHeaders(request))
return; return;
sendHeaders(exchange, content, commitCallback); sendHeaders(exchange, content, commitCallback);
}
} }
protected boolean expects100Continue(Request request) protected boolean expects100Continue(Request request)

View File

@ -211,6 +211,7 @@ public abstract class PoolingHttpDestination<C extends Connection> extends HttpD
@Override @Override
public void dump(Appendable out, String indent) throws IOException public void dump(Appendable out, String indent) throws IOException
{ {
super.dump(out, indent);
ContainerLifeCycle.dump(out, indent, Arrays.asList(connectionPool)); ContainerLifeCycle.dump(out, indent, Arrays.asList(connectionPool));
} }

View File

@ -399,10 +399,12 @@ public interface Request
ContentResponse send() throws InterruptedException, TimeoutException, ExecutionException; ContentResponse send() throws InterruptedException, TimeoutException, ExecutionException;
/** /**
* Sends this request and asynchronously notifies the given listener for response events. * <p>Sends this request and asynchronously notifies the given listener for response events.</p>
* <p /> * <p>This method should be used when the application needs to be notified of the various response events
* This method should be used when the application needs to be notified of the various response events * as they happen, or when the application needs to efficiently manage the response content.</p>
* as they happen, or when the application needs to efficiently manage the response content. * <p>The listener passed to this method may implement not only {@link Response.CompleteListener}
* but also other response listener interfaces, and all the events implemented will be notified.
* This allows application code to write a single listener class to handle all relevant events.</p>
* *
* @param listener the listener that receives response events * @param listener the listener that receives response events
*/ */

View File

@ -94,6 +94,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
private final AtomicReference<Listener> listener = new AtomicReference<>(); private final AtomicReference<Listener> listener = new AtomicReference<>();
private final DeferredContentProviderIterator iterator = new DeferredContentProviderIterator(); private final DeferredContentProviderIterator iterator = new DeferredContentProviderIterator();
private final AtomicBoolean closed = new AtomicBoolean(); private final AtomicBoolean closed = new AtomicBoolean();
private long length = -1;
private int size; private int size;
private Throwable failure; private Throwable failure;
@ -114,12 +115,23 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
if (!this.listener.compareAndSet(null, listener)) if (!this.listener.compareAndSet(null, listener))
throw new IllegalStateException(String.format("The same %s instance cannot be used in multiple requests", throw new IllegalStateException(String.format("The same %s instance cannot be used in multiple requests",
AsyncContentProvider.class.getName())); AsyncContentProvider.class.getName()));
if (isClosed())
{
synchronized (lock)
{
long total = 0;
for (AsyncChunk chunk : chunks)
total += chunk.buffer.remaining();
length = total;
}
}
} }
@Override @Override
public long getLength() public long getLength()
{ {
return -1; return length;
} }
/** /**
@ -200,6 +212,11 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback,
offer(CLOSE); offer(CLOSE);
} }
public boolean isClosed()
{
return closed.get();
}
@Override @Override
public void succeeded() public void succeeded()
{ {

View File

@ -81,6 +81,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import static java.nio.file.StandardOpenOption.CREATE; import static java.nio.file.StandardOpenOption.CREATE;
import static org.junit.Assert.assertTrue;
public class HttpClientTest extends AbstractHttpClientServerTest public class HttpClientTest extends AbstractHttpClientServerTest
{ {
@ -766,6 +767,37 @@ public class HttpClientTest extends AbstractHttpClientServerTest
Assert.assertFalse(response.getHeaders().containsKey(headerName)); Assert.assertFalse(response.getHeaders().containsKey(headerName));
} }
@Test
public void testAllHeadersDiscarded() throws Exception
{
start(new EmptyServerHandler());
int count = 10;
final CountDownLatch latch = new CountDownLatch(count);
for (int i = 0; i < count; ++i)
{
client.newRequest("localhost", connector.getLocalPort())
.scheme(scheme)
.send(new Response.Listener.Adapter()
{
@Override
public boolean onHeader(Response response, HttpField field)
{
return false;
}
@Override
public void onComplete(Result result)
{
if (result.isSucceeded())
latch.countDown();
}
});
}
assertTrue(latch.await(10, TimeUnit.SECONDS));
}
@Test @Test
public void test_HEAD_With_ResponseContentLength() throws Exception public void test_HEAD_With_ResponseContentLength() throws Exception
{ {

View File

@ -0,0 +1,610 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.proxy;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeoutException;
import javax.servlet.AsyncContext;
import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.util.HttpCookieStore;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
public abstract class AbstractProxyServlet extends HttpServlet
{
protected static final Set<String> HOP_HEADERS;
static
{
Set<String> hopHeaders = new HashSet<>();
hopHeaders.add("connection");
hopHeaders.add("keep-alive");
hopHeaders.add("proxy-authorization");
hopHeaders.add("proxy-authenticate");
hopHeaders.add("proxy-connection");
hopHeaders.add("transfer-encoding");
hopHeaders.add("te");
hopHeaders.add("trailer");
hopHeaders.add("upgrade");
HOP_HEADERS = Collections.unmodifiableSet(hopHeaders);
}
private final Set<String> _whiteList = new HashSet<>();
private final Set<String> _blackList = new HashSet<>();
protected Logger _log;
private String _hostHeader;
private String _viaHost;
private HttpClient _client;
private long _timeout;
@Override
public void init() throws ServletException
{
_log = createLogger();
ServletConfig config = getServletConfig();
_hostHeader = config.getInitParameter("hostHeader");
_viaHost = config.getInitParameter("viaHost");
if (_viaHost == null)
_viaHost = viaHost();
try
{
_client = createHttpClient();
// Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES
getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client);
String whiteList = config.getInitParameter("whiteList");
if (whiteList != null)
getWhiteListHosts().addAll(parseList(whiteList));
String blackList = config.getInitParameter("blackList");
if (blackList != null)
getBlackListHosts().addAll(parseList(blackList));
}
catch (Exception e)
{
throw new ServletException(e);
}
}
@Override
public void destroy()
{
try
{
_client.stop();
}
catch (Exception x)
{
if (_log.isDebugEnabled())
_log.debug(x);
}
}
public String getHostHeader()
{
return _hostHeader;
}
public String getViaHost()
{
return _viaHost;
}
private static String viaHost()
{
try
{
return InetAddress.getLocalHost().getHostName();
}
catch (UnknownHostException x)
{
return "localhost";
}
}
public long getTimeout()
{
return _timeout;
}
public void setTimeout(long timeout)
{
this._timeout = timeout;
}
public Set<String> getWhiteListHosts()
{
return _whiteList;
}
public Set<String> getBlackListHosts()
{
return _blackList;
}
/**
* @return a logger instance with a name derived from this servlet's name.
*/
protected Logger createLogger()
{
String servletName = getServletConfig().getServletName();
servletName = servletName.replace('-', '.');
if ((getClass().getPackage() != null) && !servletName.startsWith(getClass().getPackage().getName()))
{
servletName = getClass().getName() + "." + servletName;
}
return Log.getLogger(servletName);
}
/**
* Creates a {@link HttpClient} instance, configured with init parameters of this servlet.
* <p/>
* The init parameters used to configure the {@link HttpClient} instance are:
* <table>
* <thead>
* <tr>
* <th>init-param</th>
* <th>default</th>
* <th>description</th>
* </tr>
* </thead>
* <tbody>
* <tr>
* <td>maxThreads</td>
* <td>256</td>
* <td>The max number of threads of HttpClient's Executor. If not set, or set to the value of "-", then the
* Jetty server thread pool will be used.</td>
* </tr>
* <tr>
* <td>maxConnections</td>
* <td>32768</td>
* <td>The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}</td>
* </tr>
* <tr>
* <td>idleTimeout</td>
* <td>30000</td>
* <td>The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}</td>
* </tr>
* <tr>
* <td>timeout</td>
* <td>60000</td>
* <td>The total timeout in milliseconds, see {@link Request#timeout(long, java.util.concurrent.TimeUnit)}</td>
* </tr>
* <tr>
* <td>requestBufferSize</td>
* <td>HttpClient's default</td>
* <td>The request buffer size, see {@link HttpClient#setRequestBufferSize(int)}</td>
* </tr>
* <tr>
* <td>responseBufferSize</td>
* <td>HttpClient's default</td>
* <td>The response buffer size, see {@link HttpClient#setResponseBufferSize(int)}</td>
* </tr>
* </tbody>
* </table>
*
* @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration}
* @throws ServletException if the {@link HttpClient} cannot be created
*/
protected HttpClient createHttpClient() throws ServletException
{
ServletConfig config = getServletConfig();
HttpClient client = newHttpClient();
// Redirects must be proxied as is, not followed
client.setFollowRedirects(false);
// Must not store cookies, otherwise cookies of different clients will mix
client.setCookieStore(new HttpCookieStore.Empty());
Executor executor;
String value = config.getInitParameter("maxThreads");
if (value == null || "-".equals(value))
{
executor = (Executor)getServletContext().getAttribute("org.eclipse.jetty.server.Executor");
if (executor==null)
throw new IllegalStateException("No server executor for proxy");
}
else
{
QueuedThreadPool qtp= new QueuedThreadPool(Integer.parseInt(value));
String servletName = config.getServletName();
int dot = servletName.lastIndexOf('.');
if (dot >= 0)
servletName = servletName.substring(dot + 1);
qtp.setName(servletName);
executor=qtp;
}
client.setExecutor(executor);
value = config.getInitParameter("maxConnections");
if (value == null)
value = "256";
client.setMaxConnectionsPerDestination(Integer.parseInt(value));
value = config.getInitParameter("idleTimeout");
if (value == null)
value = "30000";
client.setIdleTimeout(Long.parseLong(value));
value = config.getInitParameter("timeout");
if (value == null)
value = "60000";
_timeout = Long.parseLong(value);
value = config.getInitParameter("requestBufferSize");
if (value != null)
client.setRequestBufferSize(Integer.parseInt(value));
value = config.getInitParameter("responseBufferSize");
if (value != null)
client.setResponseBufferSize(Integer.parseInt(value));
try
{
client.start();
// Content must not be decoded, otherwise the client gets confused
client.getContentDecoderFactories().clear();
return client;
}
catch (Exception x)
{
throw new ServletException(x);
}
}
/**
* @return a new HttpClient instance
*/
protected HttpClient newHttpClient()
{
return new HttpClient();
}
protected HttpClient getHttpClient()
{
return _client;
}
private Set<String> parseList(String list)
{
Set<String> result = new HashSet<>();
String[] hosts = list.split(",");
for (String host : hosts)
{
host = host.trim();
if (host.length() == 0)
continue;
result.add(host);
}
return result;
}
/**
* Checks the given {@code host} and {@code port} against whitelist and blacklist.
*
* @param host the host to check
* @param port the port to check
* @return true if it is allowed to be proxy to the given host and port
*/
public boolean validateDestination(String host, int port)
{
String hostPort = host + ":" + port;
if (!_whiteList.isEmpty())
{
if (!_whiteList.contains(hostPort))
{
if (_log.isDebugEnabled())
_log.debug("Host {}:{} not whitelisted", host, port);
return false;
}
}
if (!_blackList.isEmpty())
{
if (_blackList.contains(hostPort))
{
if (_log.isDebugEnabled())
_log.debug("Host {}:{} blacklisted", host, port);
return false;
}
}
return true;
}
protected String rewriteTarget(HttpServletRequest clientRequest)
{
if (!validateDestination(clientRequest.getServerName(), clientRequest.getServerPort()))
return null;
StringBuffer target = clientRequest.getRequestURL();
String query = clientRequest.getQueryString();
if (query != null)
target.append("?").append(query);
return target.toString();
}
/**
* <p>Callback method invoked when the URI rewrite performed
* in {@link #rewriteTarget(HttpServletRequest)} returns null
* indicating that no rewrite can be performed.</p>
* <p>It is possible to use blocking API in this method,
* like {@link HttpServletResponse#sendError(int)}.</p>
*
* @param clientRequest the client request
* @param clientResponse the client response
*/
protected void onProxyRewriteFailed(HttpServletRequest clientRequest, HttpServletResponse clientResponse)
{
clientResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
}
protected boolean hasContent(HttpServletRequest clientRequest)
{
return clientRequest.getContentLength() > 0 ||
clientRequest.getContentType() != null ||
clientRequest.getHeader(HttpHeader.TRANSFER_ENCODING.asString()) != null;
}
protected void copyHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
Set<String> headersToRemove = findConnectionHeaders(clientRequest);
for (Enumeration<String> headerNames = clientRequest.getHeaderNames(); headerNames.hasMoreElements();)
{
String headerName = headerNames.nextElement();
String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
if (_hostHeader != null && HttpHeader.HOST.is(headerName))
continue;
// Remove hop-by-hop headers.
if (HOP_HEADERS.contains(lowerHeaderName))
continue;
if (headersToRemove != null && headersToRemove.contains(lowerHeaderName))
continue;
for (Enumeration<String> headerValues = clientRequest.getHeaders(headerName); headerValues.hasMoreElements();)
{
String headerValue = headerValues.nextElement();
if (headerValue != null)
proxyRequest.header(headerName, headerValue);
}
}
// Force the Host header if configured
if (_hostHeader != null)
proxyRequest.header(HttpHeader.HOST, _hostHeader);
}
protected Set<String> findConnectionHeaders(HttpServletRequest clientRequest)
{
// Any header listed by the Connection header must be removed:
// http://tools.ietf.org/html/rfc7230#section-6.1.
Set<String> hopHeaders = null;
Enumeration<String> connectionHeaders = clientRequest.getHeaders(HttpHeader.CONNECTION.asString());
while (connectionHeaders.hasMoreElements())
{
String value = connectionHeaders.nextElement();
String[] values = value.split(",");
for (String name : values)
{
name = name.trim().toLowerCase(Locale.ENGLISH);
if (hopHeaders == null)
hopHeaders = new HashSet<>();
hopHeaders.add(name);
}
}
return hopHeaders;
}
protected void addProxyHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
addViaHeader(proxyRequest);
addXForwardedHeaders(clientRequest, proxyRequest);
}
protected void addViaHeader(Request proxyRequest)
{
proxyRequest.header(HttpHeader.VIA, "http/1.1 " + getViaHost());
}
protected void addXForwardedHeaders(HttpServletRequest clientRequest, Request proxyRequest)
{
proxyRequest.header(HttpHeader.X_FORWARDED_FOR, clientRequest.getRemoteAddr());
proxyRequest.header(HttpHeader.X_FORWARDED_PROTO, clientRequest.getScheme());
proxyRequest.header(HttpHeader.X_FORWARDED_HOST, clientRequest.getHeader(HttpHeader.HOST.asString()));
proxyRequest.header(HttpHeader.X_FORWARDED_SERVER, clientRequest.getLocalName());
}
protected void sendProxyRequest(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest)
{
if (_log.isDebugEnabled())
{
StringBuilder builder = new StringBuilder(clientRequest.getMethod());
builder.append(" ").append(clientRequest.getRequestURI());
String query = clientRequest.getQueryString();
if (query != null)
builder.append("?").append(query);
builder.append(" ").append(clientRequest.getProtocol()).append(System.lineSeparator());
for (Enumeration<String> headerNames = clientRequest.getHeaderNames(); headerNames.hasMoreElements();)
{
String headerName = headerNames.nextElement();
builder.append(headerName).append(": ");
for (Enumeration<String> headerValues = clientRequest.getHeaders(headerName); headerValues.hasMoreElements();)
{
String headerValue = headerValues.nextElement();
if (headerValue != null)
builder.append(headerValue);
if (headerValues.hasMoreElements())
builder.append(",");
}
builder.append(System.lineSeparator());
}
builder.append(System.lineSeparator());
_log.debug("{} proxying to upstream:{}{}{}{}",
getRequestId(clientRequest),
System.lineSeparator(),
builder,
proxyRequest,
System.lineSeparator(),
proxyRequest.getHeaders().toString().trim());
}
proxyRequest.send(newProxyResponseListener(clientRequest, proxyResponse));
}
protected abstract Response.CompleteListener newProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse);
protected void onClientRequestFailure(HttpServletRequest clientRequest, Request proxyRequest, HttpServletResponse proxyResponse, Throwable failure)
{
boolean aborted = proxyRequest.abort(failure);
if (!aborted)
{
proxyResponse.setStatus(500);
clientRequest.getAsyncContext().complete();
}
}
protected void onServerResponseHeaders(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
for (HttpField field : serverResponse.getHeaders())
{
String headerName = field.getName();
String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
if (HOP_HEADERS.contains(lowerHeaderName))
continue;
String newHeaderValue = filterServerResponseHeader(clientRequest, headerName, field.getValue());
if (newHeaderValue == null || newHeaderValue.trim().length() == 0)
continue;
proxyResponse.addHeader(headerName, newHeaderValue);
}
if (_log.isDebugEnabled())
{
StringBuilder builder = new StringBuilder(System.lineSeparator());
builder.append(clientRequest.getProtocol()).append(" ").append(proxyResponse.getStatus())
.append(" ").append(serverResponse.getReason()).append(System.lineSeparator());
for (String headerName : proxyResponse.getHeaderNames())
{
builder.append(headerName).append(": ");
for (Iterator<String> headerValues = proxyResponse.getHeaders(headerName).iterator(); headerValues.hasNext(); )
{
String headerValue = headerValues.next();
if (headerValue != null)
builder.append(headerValue);
if (headerValues.hasNext())
builder.append(",");
}
builder.append(System.lineSeparator());
}
_log.debug("{} proxying to downstream:{}{}{}{}{}",
getRequestId(clientRequest),
System.lineSeparator(),
serverResponse,
System.lineSeparator(),
serverResponse.getHeaders().toString().trim(),
System.lineSeparator(),
builder);
}
}
protected String filterServerResponseHeader(HttpServletRequest clientRequest, String headerName, String headerValue)
{
return headerValue;
}
protected void onProxyResponseSuccess(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
if (_log.isDebugEnabled())
_log.debug("{} proxying successful", getRequestId(clientRequest));
AsyncContext asyncContext = clientRequest.getAsyncContext();
asyncContext.complete();
}
protected void onProxyResponseFailure(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse, Throwable failure)
{
if (_log.isDebugEnabled())
_log.debug(getRequestId(clientRequest) + " proxying failed", failure);
if (proxyResponse.isCommitted())
{
try
{
// Use Jetty specific behavior to close connection.
proxyResponse.sendError(-1);
AsyncContext asyncContext = clientRequest.getAsyncContext();
asyncContext.complete();
}
catch (IOException x)
{
if (_log.isDebugEnabled())
_log.debug(getRequestId(clientRequest) + " could not close the connection", failure);
}
}
else
{
proxyResponse.resetBuffer();
if (failure instanceof TimeoutException)
proxyResponse.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);
else
proxyResponse.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
proxyResponse.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
AsyncContext asyncContext = clientRequest.getAsyncContext();
asyncContext.complete();
}
}
protected int getRequestId(HttpServletRequest clientRequest)
{
return System.identityHashCode(clientRequest);
}
}

View File

@ -0,0 +1,700 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.proxy;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.zip.GZIPOutputStream;
import javax.servlet.AsyncContext;
import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.ContentDecoder;
import org.eclipse.jetty.client.GZIPContentDecoder;
import org.eclipse.jetty.client.api.ContentProvider;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.DeferredContentProvider;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.io.RuntimeIOException;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.IteratingCallback;
public class AsyncMiddleManServlet extends AbstractProxyServlet
{
private static final String CLIENT_TRANSFORMER = AsyncMiddleManServlet.class.getName() + ".clientTransformer";
private static final String SERVER_TRANSFORMER = AsyncMiddleManServlet.class.getName() + ".serverTransformer";
@Override
protected void service(HttpServletRequest clientRequest, HttpServletResponse proxyResponse) throws ServletException, IOException
{
String rewrittenTarget = rewriteTarget(clientRequest);
if (_log.isDebugEnabled())
{
StringBuffer target = clientRequest.getRequestURL();
if (clientRequest.getQueryString() != null)
target.append("?").append(clientRequest.getQueryString());
_log.debug("{} rewriting: {} -> {}", getRequestId(clientRequest), target, rewrittenTarget);
}
if (rewrittenTarget == null)
{
onProxyRewriteFailed(clientRequest, proxyResponse);
return;
}
final Request proxyRequest = getHttpClient().newRequest(rewrittenTarget)
.method(clientRequest.getMethod())
.version(HttpVersion.fromString(clientRequest.getProtocol()));
boolean hasContent = hasContent(clientRequest);
copyHeaders(clientRequest, proxyRequest);
addProxyHeaders(clientRequest, proxyRequest);
final AsyncContext asyncContext = clientRequest.startAsync();
// We do not timeout the continuation, but the proxy request.
asyncContext.setTimeout(0);
proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS);
// If there is content, the send of the proxy request
// is delayed and performed when the content arrives,
// to allow optimization of the Content-Length header.
if (hasContent)
proxyRequest.content(newProxyContentProvider(clientRequest, proxyResponse, proxyRequest));
else
sendProxyRequest(clientRequest, proxyResponse, proxyRequest);
}
protected ContentProvider newProxyContentProvider(final HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest) throws IOException
{
ServletInputStream input = clientRequest.getInputStream();
DeferredContentProvider provider = new DeferredContentProvider()
{
@Override
public boolean offer(ByteBuffer buffer, Callback callback)
{
if (_log.isDebugEnabled())
_log.debug("{} proxying content to upstream: {} bytes", getRequestId(clientRequest), buffer.remaining());
return super.offer(buffer, callback);
}
};
input.setReadListener(newProxyReadListener(clientRequest, proxyResponse, proxyRequest, provider));
return provider;
}
protected ReadListener newProxyReadListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest, DeferredContentProvider provider)
{
return new ProxyReader(clientRequest, proxyResponse, proxyRequest, provider);
}
protected ProxyWriter newProxyWriteListener(HttpServletRequest clientRequest, Response proxyResponse)
{
return new ProxyWriter(clientRequest, proxyResponse);
}
protected Response.CompleteListener newProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse)
{
return new ProxyResponseListener(clientRequest, proxyResponse);
}
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return ContentTransformer.IDENTITY;
}
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return ContentTransformer.IDENTITY;
}
int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException
{
return input.read(buffer);
}
void writeProxyResponseContent(ServletOutputStream output, ByteBuffer content) throws IOException
{
write(output, content);
}
private static void write(OutputStream output, ByteBuffer content) throws IOException
{
int length = content.remaining();
int offset = 0;
byte[] buffer;
if (content.hasArray())
{
offset = content.arrayOffset();
buffer = content.array();
}
else
{
buffer = new byte[length];
content.get(buffer);
}
output.write(buffer, offset, length);
}
protected class ProxyReader extends IteratingCallback implements ReadListener
{
private final Callback failer = new Adapter()
{
@Override
public void failed(Throwable x)
{
onError(x);
}
};
private final byte[] buffer = new byte[getHttpClient().getRequestBufferSize()];
private final List<ByteBuffer> buffers = new ArrayList<>();
private final HttpServletRequest clientRequest;
private final HttpServletResponse proxyResponse;
private final Request proxyRequest;
private final DeferredContentProvider provider;
private final int contentLength;
private int length;
protected ProxyReader(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest, DeferredContentProvider provider)
{
this.clientRequest = clientRequest;
this.proxyResponse = proxyResponse;
this.proxyRequest = proxyRequest;
this.provider = provider;
this.contentLength = clientRequest.getContentLength();
}
@Override
public void onDataAvailable() throws IOException
{
iterate();
}
@Override
public void onAllDataRead() throws IOException
{
if (!provider.isClosed())
process(BufferUtil.EMPTY_BUFFER, failer, true);
if (_log.isDebugEnabled())
_log.debug("{} proxying content to upstream completed", getRequestId(clientRequest));
}
@Override
public void onError(Throwable t)
{
onClientRequestFailure(clientRequest, proxyRequest, proxyResponse, t);
}
@Override
protected Action process() throws Exception
{
ServletInputStream input = clientRequest.getInputStream();
while (input.isReady() && !input.isFinished())
{
int read = readClientRequestContent(input, buffer);
if (_log.isDebugEnabled())
_log.debug("{} asynchronous read {} bytes on {}", getRequestId(clientRequest), read, input);
if (contentLength > 0 && read > 0)
length += read;
ByteBuffer content = read > 0 ? ByteBuffer.wrap(buffer, 0, read) : BufferUtil.EMPTY_BUFFER;
boolean finished = read < 0 || length == contentLength;
process(content, this, finished);
if (read > 0)
return Action.SCHEDULED;
}
if (input.isFinished())
{
if (_log.isDebugEnabled())
_log.debug("{} asynchronous read complete on {}", getRequestId(clientRequest), input);
return Action.SUCCEEDED;
}
else
{
if (_log.isDebugEnabled())
_log.debug("{} asynchronous read pending on {}", getRequestId(clientRequest), input);
return Action.IDLE;
}
}
private void process(ByteBuffer content, Callback callback, boolean finished) throws IOException
{
ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(CLIENT_TRANSFORMER);
boolean committed = transformer != null;
if (transformer == null)
{
transformer = newClientRequestContentTransformer(clientRequest, proxyRequest);
clientRequest.setAttribute(CLIENT_TRANSFORMER, transformer);
}
if (content.hasRemaining() || finished)
{
int contentBytes = content.remaining();
transformer.transform(content, finished, buffers);
int newContentBytes = 0;
int size = buffers.size();
for (int i = 0; i < size; ++i)
{
ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining();
provider.offer(buffer, i == size - 1 ? callback : failer);
}
buffers.clear();
if (finished)
provider.close();
if (_log.isDebugEnabled())
_log.debug("{} upstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes);
if (!committed)
{
proxyRequest.header(HttpHeader.CONTENT_LENGTH, null);
sendProxyRequest(clientRequest, proxyResponse, proxyRequest);
}
if (size == 0)
succeeded();
}
}
@Override
protected void onCompleteFailure(Throwable x)
{
onError(x);
}
}
protected class ProxyResponseListener extends Response.Listener.Adapter
{
private final String WRITE_LISTENER_ATTRIBUTE = AsyncMiddleManServlet.class.getName() + ".writeListener";
private final List<ByteBuffer> buffers = new ArrayList<>();
private final AtomicBoolean complete = new AtomicBoolean();
private final HttpServletRequest clientRequest;
private final HttpServletResponse proxyResponse;
private boolean hasContent;
private long contentLength;
private long length;
protected ProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse)
{
this.clientRequest = clientRequest;
this.proxyResponse = proxyResponse;
}
@Override
public void onBegin(Response serverResponse)
{
proxyResponse.setStatus(serverResponse.getStatus());
}
@Override
public void onHeaders(Response serverResponse)
{
contentLength = serverResponse.getHeaders().getLongField(HttpHeader.CONTENT_LENGTH.asString());
onServerResponseHeaders(clientRequest, proxyResponse, serverResponse);
}
@Override
public void onContent(final Response serverResponse, ByteBuffer content, final Callback callback)
{
try
{
int contentBytes = content.remaining();
if (_log.isDebugEnabled())
_log.debug("{} received server content: {} bytes", getRequestId(clientRequest), contentBytes);
hasContent = true;
ProxyWriter proxyWriter = (ProxyWriter)clientRequest.getAttribute(WRITE_LISTENER_ATTRIBUTE);
boolean committed = proxyWriter != null;
if (proxyWriter == null)
{
proxyWriter = newProxyWriteListener(clientRequest, serverResponse);
clientRequest.setAttribute(WRITE_LISTENER_ATTRIBUTE, proxyWriter);
}
ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(SERVER_TRANSFORMER);
if (transformer == null)
{
transformer = newServerResponseContentTransformer(clientRequest, proxyResponse, serverResponse);
clientRequest.setAttribute(SERVER_TRANSFORMER, transformer);
}
length += contentBytes;
boolean finished = contentLength > 0 && length == contentLength;
transformer.transform(content, finished, buffers);
int newContentBytes = 0;
int size = buffers.size();
for (int i = 0; i < size; ++i)
{
ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining();
proxyWriter.offer(buffer, i == size - 1 ? callback : Callback.Adapter.INSTANCE);
}
buffers.clear();
if (_log.isDebugEnabled())
_log.debug("{} downstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes);
if (committed)
{
proxyWriter.onWritePossible();
}
else
{
if (contentLength > 0)
proxyResponse.setContentLength(-1);
// Setting the WriteListener triggers an invocation to
// onWritePossible(), possibly on a different thread.
proxyResponse.getOutputStream().setWriteListener(proxyWriter);
}
if (size == 0)
callback.succeeded();
}
catch (Throwable x)
{
callback.failed(x);
}
}
@Override
public void onSuccess(final Response serverResponse)
{
try
{
// If we had unknown length content, we need to call the
// transformer to signal that the content is finished.
if (contentLength < 0 && hasContent)
{
ProxyWriter proxyWriter = (ProxyWriter)clientRequest.getAttribute(WRITE_LISTENER_ATTRIBUTE);
ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(SERVER_TRANSFORMER);
transformer.transform(BufferUtil.EMPTY_BUFFER, true, buffers);
long newContentBytes = 0;
int size = buffers.size();
for (int i = 0; i < size; ++i)
{
ByteBuffer buffer = buffers.get(i);
newContentBytes += buffer.remaining();
proxyWriter.offer(buffer, i == size - 1 ? new Callback.Adapter()
{
@Override
public void failed(Throwable x)
{
if (complete.compareAndSet(false, true))
onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x);
}
} : Callback.Adapter.INSTANCE);
}
buffers.clear();
if (_log.isDebugEnabled())
_log.debug("{} downstream content transformation to {} bytes", getRequestId(clientRequest), newContentBytes);
proxyWriter.onWritePossible();
}
}
catch (Throwable x)
{
if (complete.compareAndSet(false, true))
onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x);
}
}
@Override
public void onComplete(Result result)
{
if (complete.compareAndSet(false, true))
{
if (result.isSucceeded())
onProxyResponseSuccess(clientRequest, proxyResponse, result.getResponse());
else
onProxyResponseFailure(clientRequest, proxyResponse, result.getResponse(), result.getFailure());
}
if (_log.isDebugEnabled())
_log.debug("{} proxying complete", getRequestId(clientRequest));
}
}
protected class ProxyWriter implements WriteListener
{
private final Queue<AsyncChunk> chunks = new ArrayDeque<>();
private final HttpServletRequest clientRequest;
private final Response serverResponse;
private AsyncChunk chunk;
private boolean writePending;
protected ProxyWriter(HttpServletRequest clientRequest, Response serverResponse)
{
this.clientRequest = clientRequest;
this.serverResponse = serverResponse;
}
public boolean offer(ByteBuffer content, Callback callback)
{
if (_log.isDebugEnabled())
_log.debug("{} proxying content to downstream: {} bytes", getRequestId(clientRequest), content.remaining());
return chunks.offer(new AsyncChunk(content, callback));
}
@Override
public void onWritePossible() throws IOException
{
ServletOutputStream output = clientRequest.getAsyncContext().getResponse().getOutputStream();
while (true)
{
if (writePending)
{
// The write was pending but is now complete.
writePending = false;
if (_log.isDebugEnabled())
_log.debug("{} pending async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output);
if (succeed(chunk.callback))
break;
}
else
{
chunk = chunks.poll();
if (chunk == null)
break;
writeProxyResponseContent(output, chunk.buffer);
if (output.isReady())
{
if (_log.isDebugEnabled())
_log.debug("{} async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output);
if (succeed(chunk.callback))
break;
}
else
{
writePending = true;
if (_log.isDebugEnabled())
_log.debug("{} async write pending of {} bytes on {}", getRequestId(clientRequest), chunk.length, output);
break;
}
}
}
}
private boolean succeed(Callback callback)
{
// Succeeding the callback may cause to reenter in onWritePossible()
// because typically the callback is the one that controls whether the
// content received from the server has been consumed, so succeeding
// the callback causes more content to be received from the server,
// and hence more to be written to the client by onWritePossible().
// A reentrant call to onWritePossible() performs another write,
// which may remain pending, which means that the reentrant call
// to onWritePossible() returns all the way back to just after the
// succeed of the callback. There, we cannot just loop attempting
// write, but we need to check whether we are still write pending.
callback.succeeded();
return writePending;
}
@Override
public void onError(Throwable failure)
{
AsyncChunk chunk = this.chunk;
if (chunk != null)
chunk.callback.failed(failure);
else
serverResponse.abort(failure);
}
}
// TODO: coalesce this class with the one from DeferredContentProvider ?
private static class AsyncChunk
{
private final ByteBuffer buffer;
private final Callback callback;
private final int length;
private AsyncChunk(ByteBuffer buffer, Callback callback)
{
this.buffer = Objects.requireNonNull(buffer);
this.callback = Objects.requireNonNull(callback);
this.length = buffer.remaining();
}
}
/**
* <p>Allows applications to transform upstream and downstream content.</p>
* <p>Typical use cases of transformations are URL rewriting of HTML anchors
* (where the value of the <code>href</code> attribute of &lt;a&gt; elements
* is modified by the proxy), field renaming of JSON documents, etc.</p>
* <p>Applications should override {@link #newClientRequestContentTransformer(HttpServletRequest, Request)}
* and/or {@link #newServerResponseContentTransformer(HttpServletRequest, HttpServletResponse, Response)}
* to provide the transformer implementation.</p>
*/
public interface ContentTransformer
{
/**
* The identity transformer that does not perform any transformation.
*/
public static final ContentTransformer IDENTITY = new IdentityContentTransformer();
/**
* <p>Transforms the given input byte buffers into (possibly multiple) byte buffers.</p>
* <p>The transformation must happen synchronously in the context of a call
* to this method (it is not supported to perform the transformation in another
* thread spawned during the call to this method).
* The transformation may happen or not, depending on the transformer implementation.
* For example, a buffering transformer may buffer the input aside, and only
* perform the transformation when the whole input is provided (by looking at the
* {@code finished} flag).</p>
* <p>The input buffer will be cleared and reused after the call to this method.
* Implementations that want to buffer aside the input (or part of it) must copy
* the input bytes that they want to buffer.</p>
* <p>Typical implementations:</p>
* <pre>
* // Identity transformation (no transformation, the input is copied to the output)
* public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output)
* {
* output.add(input);
* }
*
* // Discard transformation (all input is discarded)
* public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output)
* {
* // Empty
* }
*
* // Buffering identity transformation (all input is buffered aside until it is finished)
* public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output)
* {
* ByteBuffer copy = ByteBuffer.allocate(input.remaining());
* copy.put(input).flip();
* store(copy);
*
* if (finished)
* {
* List&lt;ByteBuffer&gt; copies = retrieve();
* output.addAll(copies);
* }
* }
* </pre>
*
* @param input the input content to transform (may be of length zero)
* @param finished whether the input content is finished or more will come
* @param output where to put the transformed output content
* @throws IOException in case of transformation failures
*/
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException;
}
private static class IdentityContentTransformer implements ContentTransformer
{
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output)
{
output.add(input);
}
}
public static class GZIPContentTransformer implements ContentTransformer
{
private final List<ByteBuffer> buffers = new ArrayList<>(2);
private final ContentDecoder decoder = new GZIPContentDecoder();
private final ContentTransformer transformer;
private final ByteArrayOutputStream out;
private final GZIPOutputStream gzipOut;
public GZIPContentTransformer(ContentTransformer transformer)
{
try
{
this.transformer = transformer;
this.out = new ByteArrayOutputStream();
this.gzipOut = new GZIPOutputStream(out);
}
catch (IOException x)
{
throw new RuntimeIOException(x);
}
}
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
if (!input.hasRemaining())
{
if (finished)
transformer.transform(input, true, buffers);
}
else
{
while (input.hasRemaining())
{
ByteBuffer decoded = decoder.decode(input);
if (decoded.hasRemaining())
transformer.transform(decoded, finished && !input.hasRemaining(), buffers);
}
}
if (!buffers.isEmpty())
{
ByteBuffer result = gzip(buffers, finished);
buffers.clear();
output.add(result);
}
}
private ByteBuffer gzip(List<ByteBuffer> buffers, boolean finished) throws IOException
{
for (ByteBuffer buffer : buffers)
write(gzipOut, buffer);
if (finished)
gzipOut.close();
byte[] gzipBytes = out.toByteArray();
out.reset();
return ByteBuffer.wrap(gzipBytes);
}
}
}

View File

@ -20,24 +20,14 @@ package org.eclipse.jetty.proxy;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.InetAddress;
import java.net.URI; import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.servlet.AsyncContext; import javax.servlet.AsyncContext;
import javax.servlet.ServletConfig; import javax.servlet.ServletConfig;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.UnavailableException; import javax.servlet.UnavailableException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
@ -47,15 +37,8 @@ import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response; import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.api.Result; import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.InputStreamContentProvider; import org.eclipse.jetty.client.util.InputStreamContentProvider;
import org.eclipse.jetty.http.HttpField;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpHeaderValue;
import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.HttpCookieStore;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
/** /**
* Asynchronous ProxyServlet. * Asynchronous ProxyServlet.
@ -80,308 +63,8 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool;
* *
* @see ConnectHandler * @see ConnectHandler
*/ */
public class ProxyServlet extends HttpServlet public class ProxyServlet extends AbstractProxyServlet
{ {
private static final Set<String> HOP_HEADERS = new HashSet<>();
static
{
HOP_HEADERS.add("connection");
HOP_HEADERS.add("keep-alive");
HOP_HEADERS.add("proxy-authorization");
HOP_HEADERS.add("proxy-authenticate");
HOP_HEADERS.add("proxy-connection");
HOP_HEADERS.add("transfer-encoding");
HOP_HEADERS.add("te");
HOP_HEADERS.add("trailer");
HOP_HEADERS.add("upgrade");
}
private final Set<String> _whiteList = new HashSet<>();
private final Set<String> _blackList = new HashSet<>();
protected Logger _log;
private String _hostHeader;
private String _viaHost;
private HttpClient _client;
private long _timeout;
@Override
public void init() throws ServletException
{
_log = createLogger();
ServletConfig config = getServletConfig();
_hostHeader = config.getInitParameter("hostHeader");
_viaHost = config.getInitParameter("viaHost");
if (_viaHost == null)
_viaHost = viaHost();
try
{
_client = createHttpClient();
// Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES
getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client);
String whiteList = config.getInitParameter("whiteList");
if (whiteList != null)
getWhiteListHosts().addAll(parseList(whiteList));
String blackList = config.getInitParameter("blackList");
if (blackList != null)
getBlackListHosts().addAll(parseList(blackList));
}
catch (Exception e)
{
throw new ServletException(e);
}
}
public String getViaHost()
{
return _viaHost;
}
public long getTimeout()
{
return _timeout;
}
public void setTimeout(long timeout)
{
this._timeout = timeout;
}
public Set<String> getWhiteListHosts()
{
return _whiteList;
}
public Set<String> getBlackListHosts()
{
return _blackList;
}
protected static String viaHost()
{
try
{
return InetAddress.getLocalHost().getHostName();
}
catch (UnknownHostException x)
{
return "localhost";
}
}
protected HttpClient getHttpClient()
{
return _client;
}
/**
* @return a logger instance with a name derived from this servlet's name.
*/
protected Logger createLogger()
{
String servletName = getServletConfig().getServletName();
servletName = servletName.replace('-', '.');
if ((getClass().getPackage() != null) && !servletName.startsWith(getClass().getPackage().getName()))
{
servletName = getClass().getName() + "." + servletName;
}
return Log.getLogger(servletName);
}
public void destroy()
{
try
{
_client.stop();
}
catch (Exception x)
{
if (_log.isDebugEnabled())
_log.debug(x);
}
}
/**
* Creates a {@link HttpClient} instance, configured with init parameters of this servlet.
* <p/>
* The init parameters used to configure the {@link HttpClient} instance are:
* <table>
* <thead>
* <tr>
* <th>init-param</th>
* <th>default</th>
* <th>description</th>
* </tr>
* </thead>
* <tbody>
* <tr>
* <td>maxThreads</td>
* <td>256</td>
* <td>The max number of threads of HttpClient's Executor. If not set, or set to the value of "-", then the
* Jetty server thread pool will be used.</td>
* </tr>
* <tr>
* <td>maxConnections</td>
* <td>32768</td>
* <td>The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}</td>
* </tr>
* <tr>
* <td>idleTimeout</td>
* <td>30000</td>
* <td>The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}</td>
* </tr>
* <tr>
* <td>timeout</td>
* <td>60000</td>
* <td>The total timeout in milliseconds, see {@link Request#timeout(long, TimeUnit)}</td>
* </tr>
* <tr>
* <td>requestBufferSize</td>
* <td>HttpClient's default</td>
* <td>The request buffer size, see {@link HttpClient#setRequestBufferSize(int)}</td>
* </tr>
* <tr>
* <td>responseBufferSize</td>
* <td>HttpClient's default</td>
* <td>The response buffer size, see {@link HttpClient#setResponseBufferSize(int)}</td>
* </tr>
* </tbody>
* </table>
*
* @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration}
* @throws ServletException if the {@link HttpClient} cannot be created
*/
protected HttpClient createHttpClient() throws ServletException
{
ServletConfig config = getServletConfig();
HttpClient client = newHttpClient();
// Redirects must be proxied as is, not followed
client.setFollowRedirects(false);
// Must not store cookies, otherwise cookies of different clients will mix
client.setCookieStore(new HttpCookieStore.Empty());
Executor executor;
String value = config.getInitParameter("maxThreads");
if (value == null || "-".equals(value))
{
executor = (Executor)getServletContext().getAttribute("org.eclipse.jetty.server.Executor");
if (executor==null)
throw new IllegalStateException("No server executor for proxy");
}
else
{
QueuedThreadPool qtp= new QueuedThreadPool(Integer.parseInt(value));
String servletName = config.getServletName();
int dot = servletName.lastIndexOf('.');
if (dot >= 0)
servletName = servletName.substring(dot + 1);
qtp.setName(servletName);
executor=qtp;
}
client.setExecutor(executor);
value = config.getInitParameter("maxConnections");
if (value == null)
value = "256";
client.setMaxConnectionsPerDestination(Integer.parseInt(value));
value = config.getInitParameter("idleTimeout");
if (value == null)
value = "30000";
client.setIdleTimeout(Long.parseLong(value));
value = config.getInitParameter("timeout");
if (value == null)
value = "60000";
_timeout = Long.parseLong(value);
value = config.getInitParameter("requestBufferSize");
if (value != null)
client.setRequestBufferSize(Integer.parseInt(value));
value = config.getInitParameter("responseBufferSize");
if (value != null)
client.setResponseBufferSize(Integer.parseInt(value));
try
{
client.start();
// Content must not be decoded, otherwise the client gets confused
client.getContentDecoderFactories().clear();
return client;
}
catch (Exception x)
{
throw new ServletException(x);
}
}
/**
* @return a new HttpClient instance
*/
protected HttpClient newHttpClient()
{
return new HttpClient();
}
private Set<String> parseList(String list)
{
Set<String> result = new HashSet<>();
String[] hosts = list.split(",");
for (String host : hosts)
{
host = host.trim();
if (host.length() == 0)
continue;
result.add(host);
}
return result;
}
/**
* Checks the given {@code host} and {@code port} against whitelist and blacklist.
*
* @param host the host to check
* @param port the port to check
* @return true if it is allowed to be proxy to the given host and port
*/
public boolean validateDestination(String host, int port)
{
String hostPort = host + ":" + port;
if (!_whiteList.isEmpty())
{
if (!_whiteList.contains(hostPort))
{
if (_log.isDebugEnabled())
_log.debug("Host {}:{} not whitelisted", host, port);
return false;
}
}
if (!_blackList.isEmpty())
{
if (_blackList.contains(hostPort))
{
if (_log.isDebugEnabled())
_log.debug("Host {}:{} blacklisted", host, port);
return false;
}
}
return true;
}
@Override @Override
protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException
{ {
@ -404,107 +87,25 @@ public class ProxyServlet extends HttpServlet
return; return;
} }
final Request proxyRequest = _client.newRequest(rewrittenURI) final Request proxyRequest = getHttpClient().newRequest(rewrittenURI)
.method(request.getMethod()) .method(request.getMethod())
.version(HttpVersion.fromString(request.getProtocol())); .version(HttpVersion.fromString(request.getProtocol()));
// Copy headers. copyHeaders(request, proxyRequest);
// Any header listed by the Connection header must be removed: addProxyHeaders(request, proxyRequest);
// http://tools.ietf.org/html/rfc7230#section-6.1.
Set<String> hopHeaders = null;
Enumeration<String> connectionHeaders = request.getHeaders(HttpHeader.CONNECTION.asString());
while (connectionHeaders.hasMoreElements())
{
String value = connectionHeaders.nextElement();
String[] values = value.split(",");
for (String name : values)
{
name = name.trim().toLowerCase(Locale.ENGLISH);
if (hopHeaders == null)
hopHeaders = new HashSet<>();
hopHeaders.add(name);
}
}
boolean hasContent = request.getContentLength() > 0 || request.getContentType() != null;
for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
{
String headerName = headerNames.nextElement();
String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
if (HttpHeader.TRANSFER_ENCODING.is(headerName))
hasContent = true;
if (_hostHeader != null && HttpHeader.HOST.is(headerName))
continue;
// Remove hop-by-hop headers.
if (HOP_HEADERS.contains(lowerHeaderName))
continue;
if (hopHeaders != null && hopHeaders.contains(lowerHeaderName))
continue;
for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
{
String headerValue = headerValues.nextElement();
if (headerValue != null)
proxyRequest.header(headerName, headerValue);
}
}
// Force the Host header if configured
if (_hostHeader != null)
proxyRequest.header(HttpHeader.HOST, _hostHeader);
// Add proxy headers
addViaHeader(proxyRequest);
addXForwardedHeaders(proxyRequest, request);
final AsyncContext asyncContext = request.startAsync(); final AsyncContext asyncContext = request.startAsync();
// We do not timeout the continuation, but the proxy request // We do not timeout the continuation, but the proxy request
asyncContext.setTimeout(0); asyncContext.setTimeout(0);
proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS); proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS);
if (hasContent) if (hasContent(request))
proxyRequest.content(proxyRequestContent(proxyRequest, request)); proxyRequest.content(proxyRequestContent(proxyRequest, request));
customizeProxyRequest(proxyRequest, request); customizeProxyRequest(proxyRequest, request);
if (_log.isDebugEnabled()) sendProxyRequest(request, response, proxyRequest);
{
StringBuilder builder = new StringBuilder(request.getMethod());
builder.append(" ").append(request.getRequestURI());
String query = request.getQueryString();
if (query != null)
builder.append("?").append(query);
builder.append(" ").append(request.getProtocol()).append("\r\n");
for (Enumeration<String> headerNames = request.getHeaderNames(); headerNames.hasMoreElements();)
{
String headerName = headerNames.nextElement();
builder.append(headerName).append(": ");
for (Enumeration<String> headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();)
{
String headerValue = headerValues.nextElement();
if (headerValue != null)
builder.append(headerValue);
if (headerValues.hasMoreElements())
builder.append(",");
}
builder.append("\r\n");
}
builder.append("\r\n");
_log.debug("{} proxying to upstream:{}{}{}{}",
requestId,
System.lineSeparator(),
builder,
proxyRequest,
System.lineSeparator(),
proxyRequest.getHeaders().toString().trim());
}
proxyRequest.send(newProxyResponseListener(request, response));
} }
protected ContentProvider proxyRequestContent(final Request proxyRequest, final HttpServletRequest request) throws IOException protected ContentProvider proxyRequestContent(final Request proxyRequest, final HttpServletRequest request) throws IOException
@ -524,39 +125,29 @@ public class ProxyServlet extends HttpServlet
proxyRequest.abort(failure); proxyRequest.abort(failure);
} }
/**
* @deprecated use {@link #onProxyRewriteFailed(HttpServletRequest, HttpServletResponse)}
*/
@Deprecated
protected void onRewriteFailed(HttpServletRequest request, HttpServletResponse response) throws IOException protected void onRewriteFailed(HttpServletRequest request, HttpServletResponse response) throws IOException
{ {
response.sendError(HttpServletResponse.SC_FORBIDDEN); onProxyRewriteFailed(request, response);
}
protected Request addViaHeader(Request proxyRequest)
{
return proxyRequest.header(HttpHeader.VIA, "http/1.1 " + getViaHost());
}
protected void addXForwardedHeaders(Request proxyRequest, HttpServletRequest request)
{
proxyRequest.header(HttpHeader.X_FORWARDED_FOR, request.getRemoteAddr());
proxyRequest.header(HttpHeader.X_FORWARDED_PROTO, request.getScheme());
proxyRequest.header(HttpHeader.X_FORWARDED_HOST, request.getHeader(HttpHeader.HOST.asString()));
proxyRequest.header(HttpHeader.X_FORWARDED_SERVER, request.getLocalName());
} }
/**
* @deprecated use {@link #onServerResponseHeaders(HttpServletRequest, HttpServletResponse, Response)}
*/
@Deprecated
protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
{ {
for (HttpField field : proxyResponse.getHeaders()) onServerResponseHeaders(request, response, proxyResponse);
{ }
String headerName = field.getName();
String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH);
if (HOP_HEADERS.contains(lowerHeaderName))
continue;
String newHeaderValue = filterResponseHeader(request, headerName, field.getValue()); // TODO: remove in Jetty 9.3, only here for backward compatibility.
if (newHeaderValue == null || newHeaderValue.trim().length() == 0) @Override
continue; protected String filterServerResponseHeader(HttpServletRequest clientRequest, String headerName, String headerValue)
{
response.addHeader(headerName, newHeaderValue); return filterResponseHeader(clientRequest, headerName, headerValue);
}
} }
protected void onResponseContent(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, byte[] buffer, int offset, int length, Callback callback) protected void onResponseContent(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, byte[] buffer, int offset, int length, Callback callback)
@ -574,71 +165,38 @@ public class ProxyServlet extends HttpServlet
} }
} }
/**
* @deprecated Use {@link #onProxyResponseSuccess(HttpServletRequest, HttpServletResponse, Response)}
*/
@Deprecated
protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
{ {
if (_log.isDebugEnabled()) onProxyResponseSuccess(request, response, proxyResponse);
_log.debug("{} proxying successful", getRequestId(request));
AsyncContext asyncContext = request.getAsyncContext();
asyncContext.complete();
}
protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure)
{
if (_log.isDebugEnabled())
_log.debug(getRequestId(request) + " proxying failed", failure);
if (response.isCommitted())
{
try
{
// Use Jetty specific behavior to close connection.
response.sendError(-1);
AsyncContext asyncContext = request.getAsyncContext();
asyncContext.complete();
}
catch (IOException x)
{
if (_log.isDebugEnabled())
_log.debug(getRequestId(request) + " could not close the connection", failure);
}
}
else
{
response.resetBuffer();
if (failure instanceof TimeoutException)
response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT);
else
response.setStatus(HttpServletResponse.SC_BAD_GATEWAY);
response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString());
AsyncContext asyncContext = request.getAsyncContext();
asyncContext.complete();
}
}
protected int getRequestId(HttpServletRequest request)
{
return System.identityHashCode(request);
}
protected URI rewriteURI(HttpServletRequest request)
{
if (!validateDestination(request.getServerName(), request.getServerPort()))
return null;
StringBuffer uri = request.getRequestURL();
String query = request.getQueryString();
if (query != null)
uri.append("?").append(query);
return URI.create(uri.toString());
} }
/** /**
* Extension point for subclasses to customize the proxy request. * @deprecated Use {@link #onProxyResponseFailure(HttpServletRequest, HttpServletResponse, Response, Throwable)}
* The default implementation does nothing.
*
* @param proxyRequest the proxy request to customize
* @param request the request to be proxied
*/ */
@Deprecated
protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure)
{
onProxyResponseFailure(request, response, proxyResponse, failure);
}
/**
* @deprecated use {@link #rewriteTarget(HttpServletRequest)}
*/
@Deprecated
protected URI rewriteURI(HttpServletRequest request)
{
String newTarget = rewriteTarget(request);
return newTarget == null ? null : URI.create(newTarget);
}
/**
* @deprecated use {@link #sendProxyRequest(HttpServletRequest, HttpServletResponse, Request)}
*/
@Deprecated
protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request) protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request)
{ {
} }
@ -766,33 +324,6 @@ public class ProxyServlet extends HttpServlet
public void onHeaders(Response proxyResponse) public void onHeaders(Response proxyResponse)
{ {
onResponseHeaders(request, response, proxyResponse); onResponseHeaders(request, response, proxyResponse);
if (_log.isDebugEnabled())
{
StringBuilder builder = new StringBuilder("\r\n");
builder.append(request.getProtocol()).append(" ").append(response.getStatus()).append(" ").append(proxyResponse.getReason()).append("\r\n");
for (String headerName : response.getHeaderNames())
{
builder.append(headerName).append(": ");
for (Iterator<String> headerValues = response.getHeaders(headerName).iterator(); headerValues.hasNext();)
{
String headerValue = headerValues.next();
if (headerValue != null)
builder.append(headerValue);
if (headerValues.hasNext())
builder.append(",");
}
builder.append("\r\n");
}
_log.debug("{} proxying to downstream:{}{}{}{}{}",
getRequestId(request),
System.lineSeparator(),
proxyResponse,
System.lineSeparator(),
proxyResponse.getHeaders().toString().trim(),
System.lineSeparator(),
builder);
}
} }
@Override @Override

View File

@ -0,0 +1,850 @@
//
// ========================================================================
// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.proxy;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.URLDecoder;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPOutputStream;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.client.HttpProxy;
import org.eclipse.jetty.client.api.ContentProvider;
import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.api.Request;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.api.Result;
import org.eclipse.jetty.client.util.BytesContentProvider;
import org.eclipse.jetty.client.util.DeferredContentProvider;
import org.eclipse.jetty.client.util.FutureResponseListener;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.server.HttpConfiguration;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import org.eclipse.jetty.toolchain.test.TestTracker;
import org.eclipse.jetty.util.Utf8StringBuilder;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.After;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
public class AsyncMiddleManServletTest
{
private static final Logger LOG = Log.getLogger(AsyncMiddleManServletTest.class);
@Rule
public final TestTracker tracker = new TestTracker();
private HttpClient client;
private Server proxy;
private ServerConnector proxyConnector;
private ServletContextHandler proxyContext;
private Server server;
private ServerConnector serverConnector;
private void prepareProxy(HttpServlet proxyServlet) throws Exception
{
prepareProxy(proxyServlet, new HashMap<String, String>());
}
private void prepareProxy(HttpServlet proxyServlet, Map<String, String> initParams) throws Exception
{
QueuedThreadPool proxyPool = new QueuedThreadPool();
proxyPool.setName("proxy");
proxy = new Server(proxyPool);
HttpConfiguration configuration = new HttpConfiguration();
configuration.setSendDateHeader(false);
configuration.setSendServerVersion(false);
String value = initParams.get("outputBufferSize");
if (value != null)
configuration.setOutputBufferSize(Integer.valueOf(value));
proxyConnector = new ServerConnector(proxy, new HttpConnectionFactory(configuration));
proxy.addConnector(proxyConnector);
proxyContext = new ServletContextHandler(proxy, "/", true, false);
ServletHolder proxyServletHolder = new ServletHolder(proxyServlet);
proxyServletHolder.setInitParameters(initParams);
proxyContext.addServlet(proxyServletHolder, "/*");
proxy.start();
}
private void prepareClient() throws Exception
{
QueuedThreadPool clientPool = new QueuedThreadPool();
clientPool.setName("client");
client = new HttpClient();
client.setExecutor(clientPool);
client.getProxyConfiguration().getProxies().add(new HttpProxy("localhost", proxyConnector.getLocalPort()));
client.start();
}
private void prepareServer(HttpServlet servlet) throws Exception
{
QueuedThreadPool serverPool = new QueuedThreadPool();
serverPool.setName("server");
server = new Server(serverPool);
serverConnector = new ServerConnector(server);
server.addConnector(serverConnector);
ServletContextHandler appCtx = new ServletContextHandler(server, "/", true, false);
ServletHolder appServletHolder = new ServletHolder(servlet);
appCtx.addServlet(appServletHolder, "/*");
server.start();
}
@After
public void disposeProxy() throws Exception
{
client.stop();
proxy.stop();
}
@After
public void disposeServer() throws Exception
{
server.stop();
}
@Test
public void testClientRequestSmallContentKnownLengthGzipped() throws Exception
{
// Lengths smaller than the buffer sizes preserve the Content-Length header.
testClientRequestContentKnownLengthGzipped(1024, false);
}
@Test
public void testClientRequestLargeContentKnownLengthGzipped() throws Exception
{
// Lengths bigger than the buffer sizes will force chunked mode.
testClientRequestContentKnownLengthGzipped(1024 * 1024, true);
}
private void testClientRequestContentKnownLengthGzipped(int length, final boolean expectChunked) throws Exception
{
byte[] bytes = new byte[length];
new Random().nextBytes(bytes);
prepareServer(new EchoHttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
String transferEncoding = request.getHeader(HttpHeader.TRANSFER_ENCODING.asString());
if (expectChunked)
Assert.assertNotNull(transferEncoding);
else
Assert.assertNull(transferEncoding);
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
super.service(request, response);
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new GZIPContentTransformer(ContentTransformer.IDENTITY);
}
});
prepareClient();
byte[] gzipBytes = gzip(bytes);
ContentProvider gzipContent = new BytesContentProvider(gzipBytes);
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.header(HttpHeader.CONTENT_ENCODING, "gzip")
.content(gzipContent)
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(200, response.getStatus());
Assert.assertArrayEquals(bytes, response.getContent());
}
@Test
public void testServerResponseContentKnownLengthGzipped() throws Exception
{
byte[] bytes = new byte[1024];
new Random().nextBytes(bytes);
final byte[] gzipBytes = gzip(bytes);
prepareServer(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
response.getOutputStream().write(gzipBytes);
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return new GZIPContentTransformer(ContentTransformer.IDENTITY);
}
});
prepareClient();
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(200, response.getStatus());
Assert.assertArrayEquals(bytes, response.getContent());
}
@Test
public void testTransformUpstreamAndDownstreamKnownContentLengthGzipped() throws Exception
{
String data = "<a href=\"http://google.com\">Google</a>";
byte[] bytes = data.getBytes(StandardCharsets.UTF_8);
prepareServer(new EchoHttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
super.service(request, response);
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new GZIPContentTransformer(new HrefTransformer.Client());
}
@Override
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return new GZIPContentTransformer(new HrefTransformer.Server());
}
});
prepareClient();
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.header(HttpHeader.CONTENT_ENCODING, "gzip")
.content(new BytesContentProvider(gzip(bytes)))
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(200, response.getStatus());
Assert.assertArrayEquals(bytes, response.getContent());
}
@Test
public void testManySequentialTransformations() throws Exception
{
for (int i = 0; i < 8; ++i)
testTransformUpstreamAndDownstreamKnownContentLengthGzipped();
}
@Test
public void testUpstreamTransformationBufferedGzipped() throws Exception
{
prepareServer(new EchoHttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
super.service(request, response);
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new GZIPContentTransformer(new BufferingContentTransformer());
}
});
prepareClient();
DeferredContentProvider content = new DeferredContentProvider();
Request request = client.newRequest("localhost", serverConnector.getLocalPort());
FutureResponseListener listener = new FutureResponseListener(request);
request.header(HttpHeader.CONTENT_ENCODING, "gzip")
.content(content)
.send(listener);
byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8);
content.offer(ByteBuffer.wrap(gzip(bytes)));
sleep(1000);
content.close();
ContentResponse response = listener.get(5, TimeUnit.SECONDS);
Assert.assertEquals(200, response.getStatus());
Assert.assertArrayEquals(bytes, response.getContent());
}
@Test
public void testDownstreamTransformationBufferedGzipped() throws Exception
{
prepareServer(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
ServletInputStream input = request.getInputStream();
ServletOutputStream output = response.getOutputStream();
int read;
while ((read = input.read()) >= 0)
{
output.write(read);
output.flush();
}
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return new GZIPContentTransformer(new BufferingContentTransformer());
}
});
prepareClient();
byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8);
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.header(HttpHeader.CONTENT_ENCODING, "gzip")
.content(new BytesContentProvider(gzip(bytes)))
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(200, response.getStatus());
Assert.assertArrayEquals(bytes, response.getContent());
}
@Test
public void testDiscardUpstreamAndDownstreamKnownContentLengthGzipped() throws Exception
{
final byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8);
prepareServer(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
Assert.assertEquals(-1, request.getInputStream().read());
response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip");
response.getOutputStream().write(gzip(bytes));
}
});
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new GZIPContentTransformer(new DiscardContentTransformer());
}
@Override
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return new GZIPContentTransformer(new DiscardContentTransformer());
}
});
prepareClient();
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.header(HttpHeader.CONTENT_ENCODING, "gzip")
.content(new BytesContentProvider(gzip(bytes)))
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(200, response.getStatus());
Assert.assertEquals(0, response.getContent().length);
}
@Test
public void testUpstreamTransformationThrowsBeforeCommittingProxyRequest() throws Exception
{
prepareServer(new EchoHttpServlet());
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new ContentTransformer()
{
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
throw new NullPointerException("explicitly_thrown_by_test");
}
};
}
});
prepareClient();
byte[] bytes = new byte[1024];
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.content(new BytesContentProvider(bytes))
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(500, response.getStatus());
}
@Test
public void testUpstreamTransformationThrowsAfterCommittingProxyRequest() throws Exception
{
prepareServer(new EchoHttpServlet());
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest)
{
return new ContentTransformer()
{
private int count;
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
if (++count < 2)
output.add(input);
else
throw new NullPointerException("explicitly_thrown_by_test");
}
};
}
});
prepareClient();
final CountDownLatch latch = new CountDownLatch(1);
DeferredContentProvider content = new DeferredContentProvider();
client.newRequest("localhost", serverConnector.getLocalPort())
.content(content)
.send(new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
if (result.isSucceeded() && result.getResponse().getStatus() == 502)
latch.countDown();
}
});
content.offer(ByteBuffer.allocate(512));
sleep(1000);
content.offer(ByteBuffer.allocate(512));
content.close();
Assert.assertTrue(latch.await(5, TimeUnit.SECONDS));
}
@Test
public void testDownstreamTransformationThrowsAtOnContent() throws Exception
{
testDownstreamTransformationThrows(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
// To trigger the test failure we need that onContent()
// is called twice, so the second time the test throws.
ServletOutputStream output = response.getOutputStream();
output.write(new byte[512]);
output.flush();
output.write(new byte[512]);
output.flush();
}
});
}
@Test
public void testDownstreamTransformationThrowsAtOnSuccess() throws Exception
{
testDownstreamTransformationThrows(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
// To trigger the test failure we need that onContent()
// is called only once, so the the test throws from onSuccess().
ServletOutputStream output = response.getOutputStream();
output.write(new byte[512]);
output.flush();
}
});
}
private void testDownstreamTransformationThrows(HttpServlet serverServlet) throws Exception
{
prepareServer(serverServlet);
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse)
{
return new ContentTransformer()
{
private int count;
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
if (++count < 2)
output.add(input);
else
throw new NullPointerException("explicitly_thrown_by_test");
}
};
}
});
prepareClient();
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(502, response.getStatus());
}
@Test
public void testClientRequestReadFailsOnFirstRead() throws Exception
{
prepareServer(new EchoHttpServlet());
prepareProxy(new AsyncMiddleManServlet()
{
@Override
protected int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException
{
throw new IOException("explicitly_thrown_by_test");
}
});
prepareClient();
final CountDownLatch latch = new CountDownLatch(1);
DeferredContentProvider content = new DeferredContentProvider();
client.newRequest("localhost", serverConnector.getLocalPort())
.content(content)
.send(new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
System.err.println(result);
if (result.getResponse().getStatus() == 500)
latch.countDown();
}
});
content.offer(ByteBuffer.allocate(512));
sleep(1000);
content.offer(ByteBuffer.allocate(512));
content.close();
Assert.assertTrue(latch.await(5, TimeUnit.SECONDS));
}
@Test
public void testClientRequestReadFailsOnSecondRead() throws Exception
{
prepareServer(new EchoHttpServlet());
prepareProxy(new AsyncMiddleManServlet()
{
private int count;
@Override
protected int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException
{
if (++count < 2)
return super.readClientRequestContent(input, buffer);
else
throw new IOException("explicitly_thrown_by_test");
}
});
prepareClient();
final CountDownLatch latch = new CountDownLatch(1);
DeferredContentProvider content = new DeferredContentProvider();
client.newRequest("localhost", serverConnector.getLocalPort())
.content(content)
.send(new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
if (result.getResponse().getStatus() == 502)
latch.countDown();
}
});
content.offer(ByteBuffer.allocate(512));
sleep(1000);
content.offer(ByteBuffer.allocate(512));
content.close();
Assert.assertTrue(latch.await(5, TimeUnit.SECONDS));
}
@Test
public void testProxyResponseWriteFailsOnFirstWrite() throws Exception
{
testProxyResponseWriteFails(1);
}
@Test
public void testProxyResponseWriteFailsOnSecondWrite() throws Exception
{
testProxyResponseWriteFails(2);
}
private void testProxyResponseWriteFails(final int writeCount) throws Exception
{
prepareServer(new HttpServlet()
{
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
{
ServletOutputStream output = response.getOutputStream();
output.write(new byte[512]);
output.flush();
output.write(new byte[512]);
}
});
prepareProxy(new AsyncMiddleManServlet()
{
private int count;
@Override
protected void writeProxyResponseContent(ServletOutputStream output, ByteBuffer content) throws IOException
{
if (++count < writeCount)
super.writeProxyResponseContent(output, content);
else
throw new IOException("explicitly_thrown_by_test");
}
});
prepareClient();
ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort())
.timeout(5, TimeUnit.SECONDS)
.send();
Assert.assertEquals(502, response.getStatus());
}
private void sleep(long delay) throws IOException
{
try
{
Thread.sleep(delay);
}
catch (InterruptedException x)
{
throw new InterruptedIOException();
}
}
private byte[] gzip(byte[] bytes) throws IOException
{
ByteArrayOutputStream out = new ByteArrayOutputStream();
try (GZIPOutputStream gzipOut = new GZIPOutputStream(out))
{
gzipOut.write(bytes);
}
return out.toByteArray();
}
private static abstract class HrefTransformer implements AsyncMiddleManServlet.ContentTransformer
{
private static final String PREFIX = "http://localhost/q=";
private final HrefParser parser = new HrefParser();
private final List<ByteBuffer> matches = new ArrayList<>();
private boolean matching;
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
int position = input.position();
while (input.hasRemaining())
{
boolean match = parser.parse(input);
// Get the slice of what has been parsed so far.
int limit = input.limit();
input.limit(input.position());
input.position(position);
ByteBuffer slice = input.slice();
input.position(input.limit());
input.limit(limit);
position = input.position();
if (matching)
{
if (match)
{
ByteBuffer copy = ByteBuffer.allocate(slice.remaining());
copy.put(slice).flip();
matches.add(copy);
}
else
{
matching = false;
// Transform the matches.
Utf8StringBuilder builder = new Utf8StringBuilder();
for (ByteBuffer buffer : matches)
builder.append(buffer);
String transformed = transform(builder.toString());
output.add(ByteBuffer.wrap(transformed.getBytes(StandardCharsets.UTF_8)));
output.add(slice);
}
}
else
{
if (match)
{
matching = true;
ByteBuffer copy = ByteBuffer.allocate(slice.remaining());
copy.put(slice).flip();
matches.add(copy);
}
else
{
output.add(slice);
}
}
}
}
protected abstract String transform(String value) throws IOException;
private static class Client extends HrefTransformer
{
@Override
protected String transform(String value) throws IOException
{
String result = PREFIX + URLEncoder.encode(value, "UTF-8");
LOG.debug("{} -> {}", value, result);
return result;
}
}
private static class Server extends HrefTransformer
{
@Override
protected String transform(String value) throws IOException
{
String result = URLDecoder.decode(value.substring(PREFIX.length()), "UTF-8");
LOG.debug("{} <- {}", value, result);
return result;
}
}
}
private static class HrefParser
{
private final byte[] token = {'h', 'r', 'e', 'f', '=', '"'};
private int state;
private boolean parse(ByteBuffer buffer)
{
while (buffer.hasRemaining())
{
int current = buffer.get() & 0xFF;
if (state < token.length)
{
if (Character.toLowerCase(current) != token[state])
{
state = 0;
continue;
}
++state;
if (state == token.length)
return false;
}
else
{
// Look for the ending quote.
if (current == '"')
{
buffer.position(buffer.position() - 1);
state = 0;
return true;
}
}
}
return state == token.length;
}
}
private static class BufferingContentTransformer implements AsyncMiddleManServlet.ContentTransformer
{
private final List<ByteBuffer> buffers = new ArrayList<>();
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
if (input.hasRemaining())
{
ByteBuffer copy = ByteBuffer.allocate(input.remaining());
copy.put(input).flip();
buffers.add(copy);
}
if (finished)
{
Assert.assertFalse(buffers.isEmpty());
output.addAll(buffers);
buffers.clear();
}
}
}
private static class DiscardContentTransformer implements AsyncMiddleManServlet.ContentTransformer
{
@Override
public void transform(ByteBuffer input, boolean finished, List<ByteBuffer> output) throws IOException
{
}
}
}

View File

@ -104,7 +104,8 @@ public class ProxyServletTest
{ {
return Arrays.asList(new Object[][]{ return Arrays.asList(new Object[][]{
{ProxyServlet.class}, {ProxyServlet.class},
{AsyncProxyServlet.class} {AsyncProxyServlet.class},
{AsyncMiddleManServlet.class}
}); });
} }
@ -114,13 +115,13 @@ public class ProxyServletTest
private Server proxy; private Server proxy;
private ServerConnector proxyConnector; private ServerConnector proxyConnector;
private ServletContextHandler proxyContext; private ServletContextHandler proxyContext;
private ProxyServlet proxyServlet; private AbstractProxyServlet proxyServlet;
private Server server; private Server server;
private ServerConnector serverConnector; private ServerConnector serverConnector;
public ProxyServletTest(Class<?> proxyServletClass) throws Exception public ProxyServletTest(Class<?> proxyServletClass) throws Exception
{ {
this.proxyServlet = (ProxyServlet)proxyServletClass.newInstance(); this.proxyServlet = (AbstractProxyServlet)proxyServletClass.newInstance();
} }
private void prepareProxy() throws Exception private void prepareProxy() throws Exception
@ -823,12 +824,12 @@ public class ProxyServletTest
} }
@Override @Override
protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) protected void onProxyResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse)
{ {
byte[] content = temp.remove(request.getRequestURI()).toByteArray(); byte[] content = temp.remove(request.getRequestURI()).toByteArray();
ContentResponse cached = new HttpContentResponse(proxyResponse, content, null, null); ContentResponse cached = new HttpContentResponse(proxyResponse, content, null, null);
cache.put(request.getRequestURI(), cached); cache.put(request.getRequestURI(), cached);
super.onResponseSuccess(request, response, proxyResponse); super.onProxyResponseSuccess(request, response, proxyResponse);
} }
}; };
prepareProxy(); prepareProxy();

View File

@ -40,8 +40,7 @@ import org.eclipse.jetty.util.security.Credential;
/* ------------------------------------------------------------ */ /* ------------------------------------------------------------ */
/** /**
* HashMapped User Realm with JDBC as data source. JDBCLoginService extends * HashMapped User Realm with JDBC as data source.
* HashULoginService and adds a method to fetch user information from database.
* The login() method checks the inherited Map for the user. If the user is not * The login() method checks the inherited Map for the user. If the user is not
* found, it will fetch details from the database and populate the inherited * found, it will fetch details from the database and populate the inherited
* Map. It then calls the superclass login() method to perform the actual * Map. It then calls the superclass login() method to perform the actual

View File

@ -18,7 +18,7 @@
package org.eclipse.jetty.util; package org.eclipse.jetty.util;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference;
/** /**
* A callback to be used by driver code that needs to know whether the callback has been * A callback to be used by driver code that needs to know whether the callback has been
@ -57,20 +57,63 @@ import java.util.concurrent.atomic.AtomicBoolean;
*/ */
public abstract class CompletableCallback implements Callback public abstract class CompletableCallback implements Callback
{ {
private final AtomicBoolean completed = new AtomicBoolean(); private final AtomicReference<State> state = new AtomicReference<>(State.IDLE);
@Override @Override
public void succeeded() public void succeeded()
{ {
if (!tryComplete()) while (true)
resume(); {
State current = state.get();
switch (current)
{
case IDLE:
{
if (state.compareAndSet(current, State.SUCCEEDED))
return;
break;
}
case COMPLETED:
{
if (state.compareAndSet(current, State.SUCCEEDED))
{
resume();
return;
}
break;
}
default:
{
throw new IllegalStateException(current.toString());
}
}
}
} }
@Override @Override
public void failed(Throwable x) public void failed(Throwable x)
{ {
if (!tryComplete()) while (true)
abort(x); {
State current = state.get();
switch (current)
{
case IDLE:
case COMPLETED:
{
if (state.compareAndSet(current, State.FAILED))
{
abort(x);
return;
}
break;
}
default:
{
throw new IllegalStateException(current.toString());
}
}
}
} }
/** /**
@ -80,8 +123,7 @@ public abstract class CompletableCallback implements Callback
public abstract void resume(); public abstract void resume();
/** /**
* Callback method invoked when this callback is failed * Callback method invoked when this callback is failed.
* <em>after</em> a first call to {@link #tryComplete()}.
*/ */
public abstract void abort(Throwable failure); public abstract void abort(Throwable failure);
@ -95,6 +137,32 @@ public abstract class CompletableCallback implements Callback
*/ */
public boolean tryComplete() public boolean tryComplete()
{ {
return completed.compareAndSet(false, true); while (true)
{
State current = state.get();
switch (current)
{
case IDLE:
{
if (state.compareAndSet(current, State.COMPLETED))
return true;
break;
}
case SUCCEEDED:
case FAILED:
{
return false;
}
default:
{
throw new IllegalStateException(current.toString());
}
}
}
}
private enum State
{
IDLE, SUCCEEDED, FAILED, COMPLETED
} }
} }