diff --git a/jetty-client/src/main/java/org/eclipse/jetty/client/HttpExchange.java b/jetty-client/src/main/java/org/eclipse/jetty/client/HttpExchange.java index d271f9447ed..89d2446aeb1 100644 --- a/jetty-client/src/main/java/org/eclipse/jetty/client/HttpExchange.java +++ b/jetty-client/src/main/java/org/eclipse/jetty/client/HttpExchange.java @@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import org.eclipse.jetty.client.api.Request; import org.eclipse.jetty.client.api.Response; import org.eclipse.jetty.client.api.Result; import org.eclipse.jetty.util.log.Log; @@ -60,7 +59,7 @@ public class HttpExchange return request.getConversation(); } - public Request getRequest() + public HttpRequest getRequest() { return request; } diff --git a/jetty-client/src/main/java/org/eclipse/jetty/client/util/DeferredContentProvider.java b/jetty-client/src/main/java/org/eclipse/jetty/client/util/DeferredContentProvider.java index 4794802b585..43ec5748053 100644 --- a/jetty-client/src/main/java/org/eclipse/jetty/client/util/DeferredContentProvider.java +++ b/jetty-client/src/main/java/org/eclipse/jetty/client/util/DeferredContentProvider.java @@ -94,6 +94,7 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback, private final AtomicReference listener = new AtomicReference<>(); private final DeferredContentProviderIterator iterator = new DeferredContentProviderIterator(); private final AtomicBoolean closed = new AtomicBoolean(); + private long length = -1; private int size; private Throwable failure; @@ -114,12 +115,23 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback, if (!this.listener.compareAndSet(null, listener)) throw new IllegalStateException(String.format("The same %s instance cannot be used in multiple requests", AsyncContentProvider.class.getName())); + + if (isClosed()) + { + synchronized (lock) + { + long total = 0; + for (AsyncChunk chunk : chunks) + total += chunk.buffer.remaining(); + length = total; + } + } } @Override public long getLength() { - return -1; + return length; } /** @@ -200,6 +212,11 @@ public class DeferredContentProvider implements AsyncContentProvider, Callback, offer(CLOSE); } + public boolean isClosed() + { + return closed.get(); + } + @Override public void succeeded() { diff --git a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AbstractProxyServlet.java b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AbstractProxyServlet.java new file mode 100644 index 00000000000..ffbe69ad0cb --- /dev/null +++ b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AbstractProxyServlet.java @@ -0,0 +1,610 @@ +// +// ======================================================================== +// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.proxy; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeoutException; +import javax.servlet.AsyncContext; +import javax.servlet.ServletConfig; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpHeader; +import org.eclipse.jetty.http.HttpHeaderValue; +import org.eclipse.jetty.util.HttpCookieStore; +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; +import org.eclipse.jetty.util.thread.QueuedThreadPool; + +public abstract class AbstractProxyServlet extends HttpServlet +{ + protected static final Set HOP_HEADERS; + static + { + Set hopHeaders = new HashSet<>(); + hopHeaders.add("connection"); + hopHeaders.add("keep-alive"); + hopHeaders.add("proxy-authorization"); + hopHeaders.add("proxy-authenticate"); + hopHeaders.add("proxy-connection"); + hopHeaders.add("transfer-encoding"); + hopHeaders.add("te"); + hopHeaders.add("trailer"); + hopHeaders.add("upgrade"); + HOP_HEADERS = Collections.unmodifiableSet(hopHeaders); + } + + private final Set _whiteList = new HashSet<>(); + private final Set _blackList = new HashSet<>(); + protected Logger _log; + private String _hostHeader; + private String _viaHost; + private HttpClient _client; + private long _timeout; + + @Override + public void init() throws ServletException + { + _log = createLogger(); + + ServletConfig config = getServletConfig(); + + _hostHeader = config.getInitParameter("hostHeader"); + + _viaHost = config.getInitParameter("viaHost"); + if (_viaHost == null) + _viaHost = viaHost(); + + try + { + _client = createHttpClient(); + + // Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES + getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client); + + String whiteList = config.getInitParameter("whiteList"); + if (whiteList != null) + getWhiteListHosts().addAll(parseList(whiteList)); + + String blackList = config.getInitParameter("blackList"); + if (blackList != null) + getBlackListHosts().addAll(parseList(blackList)); + } + catch (Exception e) + { + throw new ServletException(e); + } + } + + @Override + public void destroy() + { + try + { + _client.stop(); + } + catch (Exception x) + { + if (_log.isDebugEnabled()) + _log.debug(x); + } + } + + public String getHostHeader() + { + return _hostHeader; + } + + public String getViaHost() + { + return _viaHost; + } + + private static String viaHost() + { + try + { + return InetAddress.getLocalHost().getHostName(); + } + catch (UnknownHostException x) + { + return "localhost"; + } + } + + public long getTimeout() + { + return _timeout; + } + + public void setTimeout(long timeout) + { + this._timeout = timeout; + } + + public Set getWhiteListHosts() + { + return _whiteList; + } + + public Set getBlackListHosts() + { + return _blackList; + } + + /** + * @return a logger instance with a name derived from this servlet's name. + */ + protected Logger createLogger() + { + String servletName = getServletConfig().getServletName(); + servletName = servletName.replace('-', '.'); + if ((getClass().getPackage() != null) && !servletName.startsWith(getClass().getPackage().getName())) + { + servletName = getClass().getName() + "." + servletName; + } + return Log.getLogger(servletName); + } + + /** + * Creates a {@link HttpClient} instance, configured with init parameters of this servlet. + *

+ * The init parameters used to configure the {@link HttpClient} instance are: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
init-paramdefaultdescription
maxThreads256The max number of threads of HttpClient's Executor. If not set, or set to the value of "-", then the + * Jetty server thread pool will be used.
maxConnections32768The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}
idleTimeout30000The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}
timeout60000The total timeout in milliseconds, see {@link Request#timeout(long, java.util.concurrent.TimeUnit)}
requestBufferSizeHttpClient's defaultThe request buffer size, see {@link HttpClient#setRequestBufferSize(int)}
responseBufferSizeHttpClient's defaultThe response buffer size, see {@link HttpClient#setResponseBufferSize(int)}
+ * + * @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration} + * @throws ServletException if the {@link HttpClient} cannot be created + */ + protected HttpClient createHttpClient() throws ServletException + { + ServletConfig config = getServletConfig(); + + HttpClient client = newHttpClient(); + + // Redirects must be proxied as is, not followed + client.setFollowRedirects(false); + + // Must not store cookies, otherwise cookies of different clients will mix + client.setCookieStore(new HttpCookieStore.Empty()); + + Executor executor; + String value = config.getInitParameter("maxThreads"); + if (value == null || "-".equals(value)) + { + executor = (Executor)getServletContext().getAttribute("org.eclipse.jetty.server.Executor"); + if (executor==null) + throw new IllegalStateException("No server executor for proxy"); + } + else + { + QueuedThreadPool qtp= new QueuedThreadPool(Integer.parseInt(value)); + String servletName = config.getServletName(); + int dot = servletName.lastIndexOf('.'); + if (dot >= 0) + servletName = servletName.substring(dot + 1); + qtp.setName(servletName); + executor=qtp; + } + + client.setExecutor(executor); + + value = config.getInitParameter("maxConnections"); + if (value == null) + value = "256"; + client.setMaxConnectionsPerDestination(Integer.parseInt(value)); + + value = config.getInitParameter("idleTimeout"); + if (value == null) + value = "30000"; + client.setIdleTimeout(Long.parseLong(value)); + + value = config.getInitParameter("timeout"); + if (value == null) + value = "60000"; + _timeout = Long.parseLong(value); + + value = config.getInitParameter("requestBufferSize"); + if (value != null) + client.setRequestBufferSize(Integer.parseInt(value)); + + value = config.getInitParameter("responseBufferSize"); + if (value != null) + client.setResponseBufferSize(Integer.parseInt(value)); + + try + { + client.start(); + + // Content must not be decoded, otherwise the client gets confused + client.getContentDecoderFactories().clear(); + + return client; + } + catch (Exception x) + { + throw new ServletException(x); + } + } + + /** + * @return a new HttpClient instance + */ + protected HttpClient newHttpClient() + { + return new HttpClient(); + } + + protected HttpClient getHttpClient() + { + return _client; + } + + private Set parseList(String list) + { + Set result = new HashSet<>(); + String[] hosts = list.split(","); + for (String host : hosts) + { + host = host.trim(); + if (host.length() == 0) + continue; + result.add(host); + } + return result; + } + + /** + * Checks the given {@code host} and {@code port} against whitelist and blacklist. + * + * @param host the host to check + * @param port the port to check + * @return true if it is allowed to be proxy to the given host and port + */ + public boolean validateDestination(String host, int port) + { + String hostPort = host + ":" + port; + if (!_whiteList.isEmpty()) + { + if (!_whiteList.contains(hostPort)) + { + if (_log.isDebugEnabled()) + _log.debug("Host {}:{} not whitelisted", host, port); + return false; + } + } + if (!_blackList.isEmpty()) + { + if (_blackList.contains(hostPort)) + { + if (_log.isDebugEnabled()) + _log.debug("Host {}:{} blacklisted", host, port); + return false; + } + } + return true; + } + + protected String rewriteTarget(HttpServletRequest clientRequest) + { + if (!validateDestination(clientRequest.getServerName(), clientRequest.getServerPort())) + return null; + + StringBuffer target = clientRequest.getRequestURL(); + String query = clientRequest.getQueryString(); + if (query != null) + target.append("?").append(query); + return target.toString(); + } + + /** + *

Callback method invoked when the URI rewrite performed + * in {@link #rewriteTarget(HttpServletRequest)} returns null + * indicating that no rewrite can be performed.

+ *

It is possible to use blocking API in this method, + * like {@link HttpServletResponse#sendError(int)}.

+ * + * @param clientRequest the client request + * @param clientResponse the client response + */ + protected void onProxyRewriteFailed(HttpServletRequest clientRequest, HttpServletResponse clientResponse) + { + clientResponse.setStatus(HttpServletResponse.SC_FORBIDDEN); + } + + protected boolean hasContent(HttpServletRequest clientRequest) + { + return clientRequest.getContentLength() > 0 || + clientRequest.getContentType() != null || + clientRequest.getHeader(HttpHeader.TRANSFER_ENCODING.asString()) != null; + } + + protected void copyHeaders(HttpServletRequest clientRequest, Request proxyRequest) + { + Set headersToRemove = findConnectionHeaders(clientRequest); + + for (Enumeration headerNames = clientRequest.getHeaderNames(); headerNames.hasMoreElements();) + { + String headerName = headerNames.nextElement(); + String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH); + + if (_hostHeader != null && HttpHeader.HOST.is(headerName)) + continue; + + // Remove hop-by-hop headers. + if (HOP_HEADERS.contains(lowerHeaderName)) + continue; + if (headersToRemove != null && headersToRemove.contains(lowerHeaderName)) + continue; + + for (Enumeration headerValues = clientRequest.getHeaders(headerName); headerValues.hasMoreElements();) + { + String headerValue = headerValues.nextElement(); + if (headerValue != null) + proxyRequest.header(headerName, headerValue); + } + } + + // Force the Host header if configured + if (_hostHeader != null) + proxyRequest.header(HttpHeader.HOST, _hostHeader); + } + + protected Set findConnectionHeaders(HttpServletRequest clientRequest) + { + // Any header listed by the Connection header must be removed: + // http://tools.ietf.org/html/rfc7230#section-6.1. + Set hopHeaders = null; + Enumeration connectionHeaders = clientRequest.getHeaders(HttpHeader.CONNECTION.asString()); + while (connectionHeaders.hasMoreElements()) + { + String value = connectionHeaders.nextElement(); + String[] values = value.split(","); + for (String name : values) + { + name = name.trim().toLowerCase(Locale.ENGLISH); + if (hopHeaders == null) + hopHeaders = new HashSet<>(); + hopHeaders.add(name); + } + } + return hopHeaders; + } + + protected void addProxyHeaders(HttpServletRequest clientRequest, Request proxyRequest) + { + addViaHeader(proxyRequest); + addXForwardedHeaders(clientRequest, proxyRequest); + } + + protected void addViaHeader(Request proxyRequest) + { + proxyRequest.header(HttpHeader.VIA, "http/1.1 " + getViaHost()); + } + + protected void addXForwardedHeaders(HttpServletRequest clientRequest, Request proxyRequest) + { + proxyRequest.header(HttpHeader.X_FORWARDED_FOR, clientRequest.getRemoteAddr()); + proxyRequest.header(HttpHeader.X_FORWARDED_PROTO, clientRequest.getScheme()); + proxyRequest.header(HttpHeader.X_FORWARDED_HOST, clientRequest.getHeader(HttpHeader.HOST.asString())); + proxyRequest.header(HttpHeader.X_FORWARDED_SERVER, clientRequest.getLocalName()); + } + + protected void sendProxyRequest(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest) + { + if (_log.isDebugEnabled()) + { + StringBuilder builder = new StringBuilder(clientRequest.getMethod()); + builder.append(" ").append(clientRequest.getRequestURI()); + String query = clientRequest.getQueryString(); + if (query != null) + builder.append("?").append(query); + builder.append(" ").append(clientRequest.getProtocol()).append(System.lineSeparator()); + for (Enumeration headerNames = clientRequest.getHeaderNames(); headerNames.hasMoreElements();) + { + String headerName = headerNames.nextElement(); + builder.append(headerName).append(": "); + for (Enumeration headerValues = clientRequest.getHeaders(headerName); headerValues.hasMoreElements();) + { + String headerValue = headerValues.nextElement(); + if (headerValue != null) + builder.append(headerValue); + if (headerValues.hasMoreElements()) + builder.append(","); + } + builder.append(System.lineSeparator()); + } + builder.append(System.lineSeparator()); + + _log.debug("{} proxying to upstream:{}{}{}{}", + getRequestId(clientRequest), + System.lineSeparator(), + builder, + proxyRequest, + System.lineSeparator(), + proxyRequest.getHeaders().toString().trim()); + } + + proxyRequest.send(newProxyResponseListener(clientRequest, proxyResponse)); + } + + protected abstract Response.CompleteListener newProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse); + + protected void onClientRequestFailure(HttpServletRequest clientRequest, Request proxyRequest, HttpServletResponse proxyResponse, Throwable failure) + { + boolean aborted = proxyRequest.abort(failure); + if (!aborted) + { + proxyResponse.setStatus(500); + clientRequest.getAsyncContext().complete(); + } + } + + protected void onServerResponseHeaders(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + for (HttpField field : serverResponse.getHeaders()) + { + String headerName = field.getName(); + String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH); + if (HOP_HEADERS.contains(lowerHeaderName)) + continue; + + String newHeaderValue = filterServerResponseHeader(clientRequest, headerName, field.getValue()); + if (newHeaderValue == null || newHeaderValue.trim().length() == 0) + continue; + + proxyResponse.addHeader(headerName, newHeaderValue); + } + + if (_log.isDebugEnabled()) + { + StringBuilder builder = new StringBuilder(System.lineSeparator()); + builder.append(clientRequest.getProtocol()).append(" ").append(proxyResponse.getStatus()) + .append(" ").append(serverResponse.getReason()).append(System.lineSeparator()); + for (String headerName : proxyResponse.getHeaderNames()) + { + builder.append(headerName).append(": "); + for (Iterator headerValues = proxyResponse.getHeaders(headerName).iterator(); headerValues.hasNext(); ) + { + String headerValue = headerValues.next(); + if (headerValue != null) + builder.append(headerValue); + if (headerValues.hasNext()) + builder.append(","); + } + builder.append(System.lineSeparator()); + } + _log.debug("{} proxying to downstream:{}{}{}{}{}", + getRequestId(clientRequest), + System.lineSeparator(), + serverResponse, + System.lineSeparator(), + serverResponse.getHeaders().toString().trim(), + System.lineSeparator(), + builder); + } + } + + protected String filterServerResponseHeader(HttpServletRequest clientRequest, String headerName, String headerValue) + { + return headerValue; + } + + protected void onProxyResponseSuccess(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + if (_log.isDebugEnabled()) + _log.debug("{} proxying successful", getRequestId(clientRequest)); + + AsyncContext asyncContext = clientRequest.getAsyncContext(); + asyncContext.complete(); + } + + protected void onProxyResponseFailure(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse, Throwable failure) + { + if (_log.isDebugEnabled()) + _log.debug(getRequestId(clientRequest) + " proxying failed", failure); + + if (proxyResponse.isCommitted()) + { + try + { + // Use Jetty specific behavior to close connection. + proxyResponse.sendError(-1); + AsyncContext asyncContext = clientRequest.getAsyncContext(); + asyncContext.complete(); + } + catch (IOException x) + { + if (_log.isDebugEnabled()) + _log.debug(getRequestId(clientRequest) + " could not close the connection", failure); + } + } + else + { + proxyResponse.resetBuffer(); + if (failure instanceof TimeoutException) + proxyResponse.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT); + else + proxyResponse.setStatus(HttpServletResponse.SC_BAD_GATEWAY); + proxyResponse.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString()); + AsyncContext asyncContext = clientRequest.getAsyncContext(); + asyncContext.complete(); + } + } + + protected int getRequestId(HttpServletRequest clientRequest) + { + return System.identityHashCode(clientRequest); + } +} diff --git a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AsyncMiddleManServlet.java b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AsyncMiddleManServlet.java new file mode 100644 index 00000000000..0bee62ac665 --- /dev/null +++ b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/AsyncMiddleManServlet.java @@ -0,0 +1,700 @@ +// +// ======================================================================== +// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.proxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.zip.GZIPOutputStream; +import javax.servlet.AsyncContext; +import javax.servlet.ReadListener; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.eclipse.jetty.client.ContentDecoder; +import org.eclipse.jetty.client.GZIPContentDecoder; +import org.eclipse.jetty.client.api.ContentProvider; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.client.api.Result; +import org.eclipse.jetty.client.util.DeferredContentProvider; +import org.eclipse.jetty.http.HttpHeader; +import org.eclipse.jetty.http.HttpVersion; +import org.eclipse.jetty.io.RuntimeIOException; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.IteratingCallback; + +public class AsyncMiddleManServlet extends AbstractProxyServlet +{ + private static final String CLIENT_TRANSFORMER = AsyncMiddleManServlet.class.getName() + ".clientTransformer"; + private static final String SERVER_TRANSFORMER = AsyncMiddleManServlet.class.getName() + ".serverTransformer"; + + @Override + protected void service(HttpServletRequest clientRequest, HttpServletResponse proxyResponse) throws ServletException, IOException + { + String rewrittenTarget = rewriteTarget(clientRequest); + if (_log.isDebugEnabled()) + { + StringBuffer target = clientRequest.getRequestURL(); + if (clientRequest.getQueryString() != null) + target.append("?").append(clientRequest.getQueryString()); + _log.debug("{} rewriting: {} -> {}", getRequestId(clientRequest), target, rewrittenTarget); + } + if (rewrittenTarget == null) + { + onProxyRewriteFailed(clientRequest, proxyResponse); + return; + } + + final Request proxyRequest = getHttpClient().newRequest(rewrittenTarget) + .method(clientRequest.getMethod()) + .version(HttpVersion.fromString(clientRequest.getProtocol())); + + boolean hasContent = hasContent(clientRequest); + + copyHeaders(clientRequest, proxyRequest); + + addProxyHeaders(clientRequest, proxyRequest); + + final AsyncContext asyncContext = clientRequest.startAsync(); + // We do not timeout the continuation, but the proxy request. + asyncContext.setTimeout(0); + proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS); + + // If there is content, the send of the proxy request + // is delayed and performed when the content arrives, + // to allow optimization of the Content-Length header. + if (hasContent) + proxyRequest.content(newProxyContentProvider(clientRequest, proxyResponse, proxyRequest)); + else + sendProxyRequest(clientRequest, proxyResponse, proxyRequest); + } + + protected ContentProvider newProxyContentProvider(final HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest) throws IOException + { + ServletInputStream input = clientRequest.getInputStream(); + DeferredContentProvider provider = new DeferredContentProvider() + { + @Override + public boolean offer(ByteBuffer buffer, Callback callback) + { + if (_log.isDebugEnabled()) + _log.debug("{} proxying content to upstream: {} bytes", getRequestId(clientRequest), buffer.remaining()); + return super.offer(buffer, callback); + } + }; + input.setReadListener(newProxyReadListener(clientRequest, proxyResponse, proxyRequest, provider)); + return provider; + } + + protected ReadListener newProxyReadListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest, DeferredContentProvider provider) + { + return new ProxyReader(clientRequest, proxyResponse, proxyRequest, provider); + } + + protected ProxyWriter newProxyWriteListener(HttpServletRequest clientRequest, Response proxyResponse) + { + return new ProxyWriter(clientRequest, proxyResponse); + } + + protected Response.CompleteListener newProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse) + { + return new ProxyResponseListener(clientRequest, proxyResponse); + } + + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return ContentTransformer.IDENTITY; + } + + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return ContentTransformer.IDENTITY; + } + + int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException + { + return input.read(buffer); + } + + void writeProxyResponseContent(ServletOutputStream output, ByteBuffer content) throws IOException + { + write(output, content); + } + + private static void write(OutputStream output, ByteBuffer content) throws IOException + { + int length = content.remaining(); + int offset = 0; + byte[] buffer; + if (content.hasArray()) + { + offset = content.arrayOffset(); + buffer = content.array(); + } + else + { + buffer = new byte[length]; + content.get(buffer); + } + output.write(buffer, offset, length); + } + + protected class ProxyReader extends IteratingCallback implements ReadListener + { + private final Callback failer = new Adapter() + { + @Override + public void failed(Throwable x) + { + onError(x); + } + }; + private final byte[] buffer = new byte[getHttpClient().getRequestBufferSize()]; + private final List buffers = new ArrayList<>(); + private final HttpServletRequest clientRequest; + private final HttpServletResponse proxyResponse; + private final Request proxyRequest; + private final DeferredContentProvider provider; + private final int contentLength; + private int length; + + protected ProxyReader(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Request proxyRequest, DeferredContentProvider provider) + { + this.clientRequest = clientRequest; + this.proxyResponse = proxyResponse; + this.proxyRequest = proxyRequest; + this.provider = provider; + this.contentLength = clientRequest.getContentLength(); + } + + @Override + public void onDataAvailable() throws IOException + { + iterate(); + } + + @Override + public void onAllDataRead() throws IOException + { + if (!provider.isClosed()) + process(BufferUtil.EMPTY_BUFFER, failer, true); + + if (_log.isDebugEnabled()) + _log.debug("{} proxying content to upstream completed", getRequestId(clientRequest)); + } + + @Override + public void onError(Throwable t) + { + onClientRequestFailure(clientRequest, proxyRequest, proxyResponse, t); + } + + @Override + protected Action process() throws Exception + { + ServletInputStream input = clientRequest.getInputStream(); + while (input.isReady() && !input.isFinished()) + { + int read = readClientRequestContent(input, buffer); + if (_log.isDebugEnabled()) + _log.debug("{} asynchronous read {} bytes on {}", getRequestId(clientRequest), read, input); + + if (contentLength > 0 && read > 0) + length += read; + + ByteBuffer content = read > 0 ? ByteBuffer.wrap(buffer, 0, read) : BufferUtil.EMPTY_BUFFER; + boolean finished = read < 0 || length == contentLength; + process(content, this, finished); + + if (read > 0) + return Action.SCHEDULED; + } + + if (input.isFinished()) + { + if (_log.isDebugEnabled()) + _log.debug("{} asynchronous read complete on {}", getRequestId(clientRequest), input); + return Action.SUCCEEDED; + } + else + { + if (_log.isDebugEnabled()) + _log.debug("{} asynchronous read pending on {}", getRequestId(clientRequest), input); + return Action.IDLE; + } + } + + private void process(ByteBuffer content, Callback callback, boolean finished) throws IOException + { + ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(CLIENT_TRANSFORMER); + boolean committed = transformer != null; + if (transformer == null) + { + transformer = newClientRequestContentTransformer(clientRequest, proxyRequest); + clientRequest.setAttribute(CLIENT_TRANSFORMER, transformer); + } + + if (content.hasRemaining() || finished) + { + int contentBytes = content.remaining(); + + transformer.transform(content, finished, buffers); + + int newContentBytes = 0; + int size = buffers.size(); + for (int i = 0; i < size; ++i) + { + ByteBuffer buffer = buffers.get(i); + newContentBytes += buffer.remaining(); + provider.offer(buffer, i == size - 1 ? callback : failer); + } + buffers.clear(); + + if (finished) + provider.close(); + + if (_log.isDebugEnabled()) + _log.debug("{} upstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes); + + if (!committed) + { + proxyRequest.header(HttpHeader.CONTENT_LENGTH, null); + sendProxyRequest(clientRequest, proxyResponse, proxyRequest); + } + + if (size == 0) + succeeded(); + } + } + + @Override + protected void onCompleteFailure(Throwable x) + { + onError(x); + } + } + + protected class ProxyResponseListener extends Response.Listener.Adapter + { + private final String WRITE_LISTENER_ATTRIBUTE = AsyncMiddleManServlet.class.getName() + ".writeListener"; + + private final List buffers = new ArrayList<>(); + private final AtomicBoolean complete = new AtomicBoolean(); + private final HttpServletRequest clientRequest; + private final HttpServletResponse proxyResponse; + private boolean hasContent; + private long contentLength; + private long length; + + protected ProxyResponseListener(HttpServletRequest clientRequest, HttpServletResponse proxyResponse) + { + this.clientRequest = clientRequest; + this.proxyResponse = proxyResponse; + } + + @Override + public void onBegin(Response serverResponse) + { + proxyResponse.setStatus(serverResponse.getStatus()); + } + + @Override + public void onHeaders(Response serverResponse) + { + contentLength = serverResponse.getHeaders().getLongField(HttpHeader.CONTENT_LENGTH.asString()); + onServerResponseHeaders(clientRequest, proxyResponse, serverResponse); + } + + @Override + public void onContent(final Response serverResponse, ByteBuffer content, final Callback callback) + { + try + { + int contentBytes = content.remaining(); + if (_log.isDebugEnabled()) + _log.debug("{} received server content: {} bytes", getRequestId(clientRequest), contentBytes); + + hasContent = true; + + ProxyWriter proxyWriter = (ProxyWriter)clientRequest.getAttribute(WRITE_LISTENER_ATTRIBUTE); + boolean committed = proxyWriter != null; + if (proxyWriter == null) + { + proxyWriter = newProxyWriteListener(clientRequest, serverResponse); + clientRequest.setAttribute(WRITE_LISTENER_ATTRIBUTE, proxyWriter); + } + + ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(SERVER_TRANSFORMER); + if (transformer == null) + { + transformer = newServerResponseContentTransformer(clientRequest, proxyResponse, serverResponse); + clientRequest.setAttribute(SERVER_TRANSFORMER, transformer); + } + + length += contentBytes; + + boolean finished = contentLength > 0 && length == contentLength; + transformer.transform(content, finished, buffers); + + int newContentBytes = 0; + int size = buffers.size(); + for (int i = 0; i < size; ++i) + { + ByteBuffer buffer = buffers.get(i); + newContentBytes += buffer.remaining(); + proxyWriter.offer(buffer, i == size - 1 ? callback : Callback.Adapter.INSTANCE); + } + buffers.clear(); + if (_log.isDebugEnabled()) + _log.debug("{} downstream content transformation {} -> {} bytes", getRequestId(clientRequest), contentBytes, newContentBytes); + + if (committed) + { + proxyWriter.onWritePossible(); + } + else + { + if (contentLength > 0) + proxyResponse.setContentLength(-1); + + // Setting the WriteListener triggers an invocation to + // onWritePossible(), possibly on a different thread. + proxyResponse.getOutputStream().setWriteListener(proxyWriter); + } + + if (size == 0) + callback.succeeded(); + } + catch (Throwable x) + { + callback.failed(x); + } + } + + @Override + public void onSuccess(final Response serverResponse) + { + try + { + // If we had unknown length content, we need to call the + // transformer to signal that the content is finished. + if (contentLength < 0 && hasContent) + { + ProxyWriter proxyWriter = (ProxyWriter)clientRequest.getAttribute(WRITE_LISTENER_ATTRIBUTE); + ContentTransformer transformer = (ContentTransformer)clientRequest.getAttribute(SERVER_TRANSFORMER); + + transformer.transform(BufferUtil.EMPTY_BUFFER, true, buffers); + + long newContentBytes = 0; + int size = buffers.size(); + for (int i = 0; i < size; ++i) + { + ByteBuffer buffer = buffers.get(i); + newContentBytes += buffer.remaining(); + proxyWriter.offer(buffer, i == size - 1 ? new Callback.Adapter() + { + @Override + public void failed(Throwable x) + { + if (complete.compareAndSet(false, true)) + onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x); + } + } : Callback.Adapter.INSTANCE); + } + buffers.clear(); + if (_log.isDebugEnabled()) + _log.debug("{} downstream content transformation to {} bytes", getRequestId(clientRequest), newContentBytes); + + proxyWriter.onWritePossible(); + } + } + catch (Throwable x) + { + if (complete.compareAndSet(false, true)) + onProxyResponseFailure(clientRequest, proxyResponse, serverResponse, x); + } + } + + @Override + public void onComplete(Result result) + { + if (complete.compareAndSet(false, true)) + { + if (result.isSucceeded()) + onProxyResponseSuccess(clientRequest, proxyResponse, result.getResponse()); + else + onProxyResponseFailure(clientRequest, proxyResponse, result.getResponse(), result.getFailure()); + } + if (_log.isDebugEnabled()) + _log.debug("{} proxying complete", getRequestId(clientRequest)); + } + } + + protected class ProxyWriter implements WriteListener + { + private final Queue chunks = new ArrayDeque<>(); + private final HttpServletRequest clientRequest; + private final Response serverResponse; + private AsyncChunk chunk; + private boolean writePending; + + protected ProxyWriter(HttpServletRequest clientRequest, Response serverResponse) + { + this.clientRequest = clientRequest; + this.serverResponse = serverResponse; + } + + public boolean offer(ByteBuffer content, Callback callback) + { + if (_log.isDebugEnabled()) + _log.debug("{} proxying content to downstream: {} bytes", getRequestId(clientRequest), content.remaining()); + return chunks.offer(new AsyncChunk(content, callback)); + } + + @Override + public void onWritePossible() throws IOException + { + ServletOutputStream output = clientRequest.getAsyncContext().getResponse().getOutputStream(); + while (true) + { + if (writePending) + { + // The write was pending but is now complete. + writePending = false; + if (_log.isDebugEnabled()) + _log.debug("{} pending async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output); + if (succeed(chunk.callback)) + break; + } + else + { + chunk = chunks.poll(); + if (chunk == null) + break; + + writeProxyResponseContent(output, chunk.buffer); + + if (output.isReady()) + { + if (_log.isDebugEnabled()) + _log.debug("{} async write complete of {} bytes on {}", getRequestId(clientRequest), chunk.length, output); + if (succeed(chunk.callback)) + break; + } + else + { + writePending = true; + if (_log.isDebugEnabled()) + _log.debug("{} async write pending of {} bytes on {}", getRequestId(clientRequest), chunk.length, output); + break; + } + } + } + } + + private boolean succeed(Callback callback) + { + // Succeeding the callback may cause to reenter in onWritePossible() + // because typically the callback is the one that controls whether the + // content received from the server has been consumed, so succeeding + // the callback causes more content to be received from the server, + // and hence more to be written to the client by onWritePossible(). + // A reentrant call to onWritePossible() performs another write, + // which may remain pending, which means that the reentrant call + // to onWritePossible() returns all the way back to just after the + // succeed of the callback. There, we cannot just loop attempting + // write, but we need to check whether we are still write pending. + callback.succeeded(); + return writePending; + } + + @Override + public void onError(Throwable failure) + { + AsyncChunk chunk = this.chunk; + if (chunk != null) + chunk.callback.failed(failure); + else + serverResponse.abort(failure); + } + } + + // TODO: coalesce this class with the one from DeferredContentProvider ? + private static class AsyncChunk + { + private final ByteBuffer buffer; + private final Callback callback; + private final int length; + + private AsyncChunk(ByteBuffer buffer, Callback callback) + { + this.buffer = Objects.requireNonNull(buffer); + this.callback = Objects.requireNonNull(callback); + this.length = buffer.remaining(); + } + } + + /** + *

Allows applications to transform upstream and downstream content.

+ *

Typical use cases of transformations are URL rewriting of HTML anchors + * (where the value of the href attribute of <a> elements + * is modified by the proxy), field renaming of JSON documents, etc.

+ *

Applications should override {@link #newClientRequestContentTransformer(HttpServletRequest, Request)} + * and/or {@link #newServerResponseContentTransformer(HttpServletRequest, HttpServletResponse, Response)} + * to provide the transformer implementation.

+ */ + public interface ContentTransformer + { + /** + * The identity transformer that does not perform any transformation. + */ + public static final ContentTransformer IDENTITY = new IdentityContentTransformer(); + + /** + *

Transforms the given input byte buffers into (possibly multiple) byte buffers.

+ *

The transformation must happen synchronously in the context of a call + * to this method (it is not supported to perform the transformation in another + * thread spawned during the call to this method). + * The transformation may happen or not, depending on the transformer implementation. + * For example, a buffering transformer may buffer the input aside, and only + * perform the transformation when the whole input is provided (by looking at the + * {@code finished} flag).

+ *

The input buffer will be cleared and reused after the call to this method. + * Implementations that want to buffer aside the input (or part of it) must copy + * the input bytes that they want to buffer.

+ *

Typical implementations:

+ *
+         * // Identity transformation (no transformation, the input is copied to the output)
+         * public void transform(ByteBuffer input, boolean finished, List output)
+         * {
+         *     output.add(input);
+         * }
+         *
+         * // Discard transformation (all input is discarded)
+         * public void transform(ByteBuffer input, boolean finished, List output)
+         * {
+         *     // Empty
+         * }
+         *
+         * // Buffering identity transformation (all input is buffered aside until it is finished)
+         * public void transform(ByteBuffer input, boolean finished, List output)
+         * {
+         *     ByteBuffer copy = ByteBuffer.allocate(input.remaining());
+         *     copy.put(input).flip();
+         *     store(copy);
+         *
+         *     if (finished)
+         *     {
+         *         List<ByteBuffer> copies = retrieve();
+         *         output.addAll(copies);
+         *     }
+         * }
+         * 
+ * + * @param input the input content to transform (may be of length zero) + * @param finished whether the input content is finished or more will come + * @param output where to put the transformed output content + * @throws IOException in case of transformation failures + */ + public void transform(ByteBuffer input, boolean finished, List output) throws IOException; + } + + private static class IdentityContentTransformer implements ContentTransformer + { + @Override + public void transform(ByteBuffer input, boolean finished, List output) + { + output.add(input); + } + } + + public static class GZIPContentTransformer implements ContentTransformer + { + private final List buffers = new ArrayList<>(2); + private final ContentDecoder decoder = new GZIPContentDecoder(); + private final ContentTransformer transformer; + private final ByteArrayOutputStream out; + private final GZIPOutputStream gzipOut; + + public GZIPContentTransformer(ContentTransformer transformer) + { + try + { + this.transformer = transformer; + this.out = new ByteArrayOutputStream(); + this.gzipOut = new GZIPOutputStream(out); + } + catch (IOException x) + { + throw new RuntimeIOException(x); + } + } + + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + if (!input.hasRemaining()) + { + if (finished) + transformer.transform(input, true, buffers); + } + else + { + while (input.hasRemaining()) + { + ByteBuffer decoded = decoder.decode(input); + if (decoded.hasRemaining()) + transformer.transform(decoded, finished && !input.hasRemaining(), buffers); + } + } + + if (!buffers.isEmpty()) + { + ByteBuffer result = gzip(buffers, finished); + buffers.clear(); + output.add(result); + } + } + + private ByteBuffer gzip(List buffers, boolean finished) throws IOException + { + for (ByteBuffer buffer : buffers) + write(gzipOut, buffer); + if (finished) + gzipOut.close(); + byte[] gzipBytes = out.toByteArray(); + out.reset(); + return ByteBuffer.wrap(gzipBytes); + } + } +} diff --git a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ProxyServlet.java b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ProxyServlet.java index 3d9c34c65b1..b92d2fa8f37 100644 --- a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ProxyServlet.java +++ b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ProxyServlet.java @@ -20,24 +20,14 @@ package org.eclipse.jetty.proxy; import java.io.IOException; import java.io.InputStream; -import java.net.InetAddress; import java.net.URI; -import java.net.UnknownHostException; import java.nio.ByteBuffer; -import java.util.Enumeration; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Locale; -import java.util.Set; -import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import javax.servlet.AsyncContext; import javax.servlet.ServletConfig; import javax.servlet.ServletContext; import javax.servlet.ServletException; import javax.servlet.UnavailableException; -import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -47,15 +37,8 @@ import org.eclipse.jetty.client.api.Request; import org.eclipse.jetty.client.api.Response; import org.eclipse.jetty.client.api.Result; import org.eclipse.jetty.client.util.InputStreamContentProvider; -import org.eclipse.jetty.http.HttpField; -import org.eclipse.jetty.http.HttpHeader; -import org.eclipse.jetty.http.HttpHeaderValue; import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.util.Callback; -import org.eclipse.jetty.util.HttpCookieStore; -import org.eclipse.jetty.util.log.Log; -import org.eclipse.jetty.util.log.Logger; -import org.eclipse.jetty.util.thread.QueuedThreadPool; /** * Asynchronous ProxyServlet. @@ -80,308 +63,8 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool; * * @see ConnectHandler */ -public class ProxyServlet extends HttpServlet +public class ProxyServlet extends AbstractProxyServlet { - private static final Set HOP_HEADERS = new HashSet<>(); - static - { - HOP_HEADERS.add("connection"); - HOP_HEADERS.add("keep-alive"); - HOP_HEADERS.add("proxy-authorization"); - HOP_HEADERS.add("proxy-authenticate"); - HOP_HEADERS.add("proxy-connection"); - HOP_HEADERS.add("transfer-encoding"); - HOP_HEADERS.add("te"); - HOP_HEADERS.add("trailer"); - HOP_HEADERS.add("upgrade"); - } - - private final Set _whiteList = new HashSet<>(); - private final Set _blackList = new HashSet<>(); - - protected Logger _log; - private String _hostHeader; - private String _viaHost; - private HttpClient _client; - private long _timeout; - - @Override - public void init() throws ServletException - { - _log = createLogger(); - - ServletConfig config = getServletConfig(); - - _hostHeader = config.getInitParameter("hostHeader"); - - _viaHost = config.getInitParameter("viaHost"); - if (_viaHost == null) - _viaHost = viaHost(); - - try - { - _client = createHttpClient(); - - // Put the HttpClient in the context to leverage ContextHandler.MANAGED_ATTRIBUTES - getServletContext().setAttribute(config.getServletName() + ".HttpClient", _client); - - String whiteList = config.getInitParameter("whiteList"); - if (whiteList != null) - getWhiteListHosts().addAll(parseList(whiteList)); - - String blackList = config.getInitParameter("blackList"); - if (blackList != null) - getBlackListHosts().addAll(parseList(blackList)); - } - catch (Exception e) - { - throw new ServletException(e); - } - } - - public String getViaHost() - { - return _viaHost; - } - - public long getTimeout() - { - return _timeout; - } - - public void setTimeout(long timeout) - { - this._timeout = timeout; - } - - public Set getWhiteListHosts() - { - return _whiteList; - } - - public Set getBlackListHosts() - { - return _blackList; - } - - protected static String viaHost() - { - try - { - return InetAddress.getLocalHost().getHostName(); - } - catch (UnknownHostException x) - { - return "localhost"; - } - } - - protected HttpClient getHttpClient() - { - return _client; - } - - /** - * @return a logger instance with a name derived from this servlet's name. - */ - protected Logger createLogger() - { - String servletName = getServletConfig().getServletName(); - servletName = servletName.replace('-', '.'); - if ((getClass().getPackage() != null) && !servletName.startsWith(getClass().getPackage().getName())) - { - servletName = getClass().getName() + "." + servletName; - } - return Log.getLogger(servletName); - } - - public void destroy() - { - try - { - _client.stop(); - } - catch (Exception x) - { - if (_log.isDebugEnabled()) - _log.debug(x); - } - } - - /** - * Creates a {@link HttpClient} instance, configured with init parameters of this servlet. - *

- * The init parameters used to configure the {@link HttpClient} instance are: - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - *
init-paramdefaultdescription
maxThreads256The max number of threads of HttpClient's Executor. If not set, or set to the value of "-", then the - * Jetty server thread pool will be used.
maxConnections32768The max number of connections per destination, see {@link HttpClient#setMaxConnectionsPerDestination(int)}
idleTimeout30000The idle timeout in milliseconds, see {@link HttpClient#setIdleTimeout(long)}
timeout60000The total timeout in milliseconds, see {@link Request#timeout(long, TimeUnit)}
requestBufferSizeHttpClient's defaultThe request buffer size, see {@link HttpClient#setRequestBufferSize(int)}
responseBufferSizeHttpClient's defaultThe response buffer size, see {@link HttpClient#setResponseBufferSize(int)}
- * - * @return a {@link HttpClient} configured from the {@link #getServletConfig() servlet configuration} - * @throws ServletException if the {@link HttpClient} cannot be created - */ - protected HttpClient createHttpClient() throws ServletException - { - ServletConfig config = getServletConfig(); - - HttpClient client = newHttpClient(); - - // Redirects must be proxied as is, not followed - client.setFollowRedirects(false); - - // Must not store cookies, otherwise cookies of different clients will mix - client.setCookieStore(new HttpCookieStore.Empty()); - - Executor executor; - String value = config.getInitParameter("maxThreads"); - if (value == null || "-".equals(value)) - { - executor = (Executor)getServletContext().getAttribute("org.eclipse.jetty.server.Executor"); - if (executor==null) - throw new IllegalStateException("No server executor for proxy"); - } - else - { - QueuedThreadPool qtp= new QueuedThreadPool(Integer.parseInt(value)); - String servletName = config.getServletName(); - int dot = servletName.lastIndexOf('.'); - if (dot >= 0) - servletName = servletName.substring(dot + 1); - qtp.setName(servletName); - executor=qtp; - } - - client.setExecutor(executor); - - value = config.getInitParameter("maxConnections"); - if (value == null) - value = "256"; - client.setMaxConnectionsPerDestination(Integer.parseInt(value)); - - value = config.getInitParameter("idleTimeout"); - if (value == null) - value = "30000"; - client.setIdleTimeout(Long.parseLong(value)); - - value = config.getInitParameter("timeout"); - if (value == null) - value = "60000"; - _timeout = Long.parseLong(value); - - value = config.getInitParameter("requestBufferSize"); - if (value != null) - client.setRequestBufferSize(Integer.parseInt(value)); - - value = config.getInitParameter("responseBufferSize"); - if (value != null) - client.setResponseBufferSize(Integer.parseInt(value)); - - try - { - client.start(); - - // Content must not be decoded, otherwise the client gets confused - client.getContentDecoderFactories().clear(); - - return client; - } - catch (Exception x) - { - throw new ServletException(x); - } - } - - /** - * @return a new HttpClient instance - */ - protected HttpClient newHttpClient() - { - return new HttpClient(); - } - - private Set parseList(String list) - { - Set result = new HashSet<>(); - String[] hosts = list.split(","); - for (String host : hosts) - { - host = host.trim(); - if (host.length() == 0) - continue; - result.add(host); - } - return result; - } - - /** - * Checks the given {@code host} and {@code port} against whitelist and blacklist. - * - * @param host the host to check - * @param port the port to check - * @return true if it is allowed to be proxy to the given host and port - */ - public boolean validateDestination(String host, int port) - { - String hostPort = host + ":" + port; - if (!_whiteList.isEmpty()) - { - if (!_whiteList.contains(hostPort)) - { - if (_log.isDebugEnabled()) - _log.debug("Host {}:{} not whitelisted", host, port); - return false; - } - } - if (!_blackList.isEmpty()) - { - if (_blackList.contains(hostPort)) - { - if (_log.isDebugEnabled()) - _log.debug("Host {}:{} blacklisted", host, port); - return false; - } - } - return true; - } - @Override protected void service(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { @@ -404,107 +87,25 @@ public class ProxyServlet extends HttpServlet return; } - final Request proxyRequest = _client.newRequest(rewrittenURI) + final Request proxyRequest = getHttpClient().newRequest(rewrittenURI) .method(request.getMethod()) .version(HttpVersion.fromString(request.getProtocol())); - // Copy headers. + copyHeaders(request, proxyRequest); - // Any header listed by the Connection header must be removed: - // http://tools.ietf.org/html/rfc7230#section-6.1. - Set hopHeaders = null; - Enumeration connectionHeaders = request.getHeaders(HttpHeader.CONNECTION.asString()); - while (connectionHeaders.hasMoreElements()) - { - String value = connectionHeaders.nextElement(); - String[] values = value.split(","); - for (String name : values) - { - name = name.trim().toLowerCase(Locale.ENGLISH); - if (hopHeaders == null) - hopHeaders = new HashSet<>(); - hopHeaders.add(name); - } - } - - boolean hasContent = request.getContentLength() > 0 || request.getContentType() != null; - for (Enumeration headerNames = request.getHeaderNames(); headerNames.hasMoreElements();) - { - String headerName = headerNames.nextElement(); - String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH); - - if (HttpHeader.TRANSFER_ENCODING.is(headerName)) - hasContent = true; - - if (_hostHeader != null && HttpHeader.HOST.is(headerName)) - continue; - - // Remove hop-by-hop headers. - if (HOP_HEADERS.contains(lowerHeaderName)) - continue; - if (hopHeaders != null && hopHeaders.contains(lowerHeaderName)) - continue; - - for (Enumeration headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();) - { - String headerValue = headerValues.nextElement(); - if (headerValue != null) - proxyRequest.header(headerName, headerValue); - } - } - - // Force the Host header if configured - if (_hostHeader != null) - proxyRequest.header(HttpHeader.HOST, _hostHeader); - - // Add proxy headers - addViaHeader(proxyRequest); - addXForwardedHeaders(proxyRequest, request); + addProxyHeaders(request, proxyRequest); final AsyncContext asyncContext = request.startAsync(); // We do not timeout the continuation, but the proxy request asyncContext.setTimeout(0); proxyRequest.timeout(getTimeout(), TimeUnit.MILLISECONDS); - if (hasContent) + if (hasContent(request)) proxyRequest.content(proxyRequestContent(proxyRequest, request)); customizeProxyRequest(proxyRequest, request); - if (_log.isDebugEnabled()) - { - StringBuilder builder = new StringBuilder(request.getMethod()); - builder.append(" ").append(request.getRequestURI()); - String query = request.getQueryString(); - if (query != null) - builder.append("?").append(query); - builder.append(" ").append(request.getProtocol()).append("\r\n"); - for (Enumeration headerNames = request.getHeaderNames(); headerNames.hasMoreElements();) - { - String headerName = headerNames.nextElement(); - builder.append(headerName).append(": "); - for (Enumeration headerValues = request.getHeaders(headerName); headerValues.hasMoreElements();) - { - String headerValue = headerValues.nextElement(); - if (headerValue != null) - builder.append(headerValue); - if (headerValues.hasMoreElements()) - builder.append(","); - } - builder.append("\r\n"); - } - builder.append("\r\n"); - - _log.debug("{} proxying to upstream:{}{}{}{}", - requestId, - System.lineSeparator(), - builder, - proxyRequest, - System.lineSeparator(), - proxyRequest.getHeaders().toString().trim()); - } - - proxyRequest.send(newProxyResponseListener(request, response)); + sendProxyRequest(request, response, proxyRequest); } protected ContentProvider proxyRequestContent(final Request proxyRequest, final HttpServletRequest request) throws IOException @@ -524,39 +125,29 @@ public class ProxyServlet extends HttpServlet proxyRequest.abort(failure); } + /** + * @deprecated use {@link #onProxyRewriteFailed(HttpServletRequest, HttpServletResponse)} + */ + @Deprecated protected void onRewriteFailed(HttpServletRequest request, HttpServletResponse response) throws IOException { - response.sendError(HttpServletResponse.SC_FORBIDDEN); - } - - protected Request addViaHeader(Request proxyRequest) - { - return proxyRequest.header(HttpHeader.VIA, "http/1.1 " + getViaHost()); - } - - protected void addXForwardedHeaders(Request proxyRequest, HttpServletRequest request) - { - proxyRequest.header(HttpHeader.X_FORWARDED_FOR, request.getRemoteAddr()); - proxyRequest.header(HttpHeader.X_FORWARDED_PROTO, request.getScheme()); - proxyRequest.header(HttpHeader.X_FORWARDED_HOST, request.getHeader(HttpHeader.HOST.asString())); - proxyRequest.header(HttpHeader.X_FORWARDED_SERVER, request.getLocalName()); + onProxyRewriteFailed(request, response); } + /** + * @deprecated use {@link #onServerResponseHeaders(HttpServletRequest, HttpServletResponse, Response)} + */ + @Deprecated protected void onResponseHeaders(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) { - for (HttpField field : proxyResponse.getHeaders()) - { - String headerName = field.getName(); - String lowerHeaderName = headerName.toLowerCase(Locale.ENGLISH); - if (HOP_HEADERS.contains(lowerHeaderName)) - continue; + onServerResponseHeaders(request, response, proxyResponse); + } - String newHeaderValue = filterResponseHeader(request, headerName, field.getValue()); - if (newHeaderValue == null || newHeaderValue.trim().length() == 0) - continue; - - response.addHeader(headerName, newHeaderValue); - } + // TODO: remove in Jetty 9.3, only here for backward compatibility. + @Override + protected String filterServerResponseHeader(HttpServletRequest clientRequest, String headerName, String headerValue) + { + return filterResponseHeader(clientRequest, headerName, headerValue); } protected void onResponseContent(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, byte[] buffer, int offset, int length, Callback callback) @@ -574,71 +165,38 @@ public class ProxyServlet extends HttpServlet } } + /** + * @deprecated Use {@link #onProxyResponseSuccess(HttpServletRequest, HttpServletResponse, Response)} + */ + @Deprecated protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) { - if (_log.isDebugEnabled()) - _log.debug("{} proxying successful", getRequestId(request)); - AsyncContext asyncContext = request.getAsyncContext(); - asyncContext.complete(); - } - - protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure) - { - if (_log.isDebugEnabled()) - _log.debug(getRequestId(request) + " proxying failed", failure); - if (response.isCommitted()) - { - try - { - // Use Jetty specific behavior to close connection. - response.sendError(-1); - AsyncContext asyncContext = request.getAsyncContext(); - asyncContext.complete(); - } - catch (IOException x) - { - if (_log.isDebugEnabled()) - _log.debug(getRequestId(request) + " could not close the connection", failure); - } - } - else - { - response.resetBuffer(); - if (failure instanceof TimeoutException) - response.setStatus(HttpServletResponse.SC_GATEWAY_TIMEOUT); - else - response.setStatus(HttpServletResponse.SC_BAD_GATEWAY); - response.setHeader(HttpHeader.CONNECTION.asString(), HttpHeaderValue.CLOSE.asString()); - AsyncContext asyncContext = request.getAsyncContext(); - asyncContext.complete(); - } - } - - protected int getRequestId(HttpServletRequest request) - { - return System.identityHashCode(request); - } - - protected URI rewriteURI(HttpServletRequest request) - { - if (!validateDestination(request.getServerName(), request.getServerPort())) - return null; - - StringBuffer uri = request.getRequestURL(); - String query = request.getQueryString(); - if (query != null) - uri.append("?").append(query); - - return URI.create(uri.toString()); + onProxyResponseSuccess(request, response, proxyResponse); } /** - * Extension point for subclasses to customize the proxy request. - * The default implementation does nothing. - * - * @param proxyRequest the proxy request to customize - * @param request the request to be proxied + * @deprecated Use {@link #onProxyResponseFailure(HttpServletRequest, HttpServletResponse, Response, Throwable)} */ + @Deprecated + protected void onResponseFailure(HttpServletRequest request, HttpServletResponse response, Response proxyResponse, Throwable failure) + { + onProxyResponseFailure(request, response, proxyResponse, failure); + } + + /** + * @deprecated use {@link #rewriteTarget(HttpServletRequest)} + */ + @Deprecated + protected URI rewriteURI(HttpServletRequest request) + { + String newTarget = rewriteTarget(request); + return newTarget == null ? null : URI.create(newTarget); + } + + /** + * @deprecated use {@link #sendProxyRequest(HttpServletRequest, HttpServletResponse, Request)} + */ + @Deprecated protected void customizeProxyRequest(Request proxyRequest, HttpServletRequest request) { } @@ -766,33 +324,6 @@ public class ProxyServlet extends HttpServlet public void onHeaders(Response proxyResponse) { onResponseHeaders(request, response, proxyResponse); - - if (_log.isDebugEnabled()) - { - StringBuilder builder = new StringBuilder("\r\n"); - builder.append(request.getProtocol()).append(" ").append(response.getStatus()).append(" ").append(proxyResponse.getReason()).append("\r\n"); - for (String headerName : response.getHeaderNames()) - { - builder.append(headerName).append(": "); - for (Iterator headerValues = response.getHeaders(headerName).iterator(); headerValues.hasNext();) - { - String headerValue = headerValues.next(); - if (headerValue != null) - builder.append(headerValue); - if (headerValues.hasNext()) - builder.append(","); - } - builder.append("\r\n"); - } - _log.debug("{} proxying to downstream:{}{}{}{}{}", - getRequestId(request), - System.lineSeparator(), - proxyResponse, - System.lineSeparator(), - proxyResponse.getHeaders().toString().trim(), - System.lineSeparator(), - builder); - } } @Override diff --git a/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/AsyncMiddleManServletTest.java b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/AsyncMiddleManServletTest.java new file mode 100644 index 00000000000..d88a977502c --- /dev/null +++ b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/AsyncMiddleManServletTest.java @@ -0,0 +1,850 @@ +// +// ======================================================================== +// Copyright (c) 1995-2015 Mort Bay Consulting Pty. Ltd. +// ------------------------------------------------------------------------ +// All rights reserved. This program and the accompanying materials +// are made available under the terms of the Eclipse Public License v1.0 +// and Apache License v2.0 which accompanies this distribution. +// +// The Eclipse Public License is available at +// http://www.eclipse.org/legal/epl-v10.html +// +// The Apache License v2.0 is available at +// http://www.opensource.org/licenses/apache2.0.php +// +// You may elect to redistribute this code under either of these licenses. +// ======================================================================== +// + +package org.eclipse.jetty.proxy; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InterruptedIOException; +import java.net.URLDecoder; +import java.net.URLEncoder; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.zip.GZIPOutputStream; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.HttpProxy; +import org.eclipse.jetty.client.api.ContentProvider; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.api.Response; +import org.eclipse.jetty.client.api.Result; +import org.eclipse.jetty.client.util.BytesContentProvider; +import org.eclipse.jetty.client.util.DeferredContentProvider; +import org.eclipse.jetty.client.util.FutureResponseListener; +import org.eclipse.jetty.http.HttpHeader; +import org.eclipse.jetty.server.HttpConfiguration; +import org.eclipse.jetty.server.HttpConnectionFactory; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.eclipse.jetty.toolchain.test.TestTracker; +import org.eclipse.jetty.util.Utf8StringBuilder; +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; +import org.eclipse.jetty.util.thread.QueuedThreadPool; +import org.junit.After; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +public class AsyncMiddleManServletTest +{ + private static final Logger LOG = Log.getLogger(AsyncMiddleManServletTest.class); + @Rule + public final TestTracker tracker = new TestTracker(); + private HttpClient client; + private Server proxy; + private ServerConnector proxyConnector; + private ServletContextHandler proxyContext; + private Server server; + private ServerConnector serverConnector; + + private void prepareProxy(HttpServlet proxyServlet) throws Exception + { + prepareProxy(proxyServlet, new HashMap()); + } + + private void prepareProxy(HttpServlet proxyServlet, Map initParams) throws Exception + { + QueuedThreadPool proxyPool = new QueuedThreadPool(); + proxyPool.setName("proxy"); + proxy = new Server(proxyPool); + + HttpConfiguration configuration = new HttpConfiguration(); + configuration.setSendDateHeader(false); + configuration.setSendServerVersion(false); + String value = initParams.get("outputBufferSize"); + if (value != null) + configuration.setOutputBufferSize(Integer.valueOf(value)); + proxyConnector = new ServerConnector(proxy, new HttpConnectionFactory(configuration)); + proxy.addConnector(proxyConnector); + + proxyContext = new ServletContextHandler(proxy, "/", true, false); + ServletHolder proxyServletHolder = new ServletHolder(proxyServlet); + proxyServletHolder.setInitParameters(initParams); + proxyContext.addServlet(proxyServletHolder, "/*"); + + proxy.start(); + } + + private void prepareClient() throws Exception + { + QueuedThreadPool clientPool = new QueuedThreadPool(); + clientPool.setName("client"); + client = new HttpClient(); + client.setExecutor(clientPool); + client.getProxyConfiguration().getProxies().add(new HttpProxy("localhost", proxyConnector.getLocalPort())); + client.start(); + } + + private void prepareServer(HttpServlet servlet) throws Exception + { + QueuedThreadPool serverPool = new QueuedThreadPool(); + serverPool.setName("server"); + server = new Server(serverPool); + serverConnector = new ServerConnector(server); + server.addConnector(serverConnector); + + ServletContextHandler appCtx = new ServletContextHandler(server, "/", true, false); + ServletHolder appServletHolder = new ServletHolder(servlet); + appCtx.addServlet(appServletHolder, "/*"); + + server.start(); + } + + @After + public void disposeProxy() throws Exception + { + client.stop(); + proxy.stop(); + } + + @After + public void disposeServer() throws Exception + { + server.stop(); + } + + @Test + public void testClientRequestSmallContentKnownLengthGzipped() throws Exception + { + // Lengths smaller than the buffer sizes preserve the Content-Length header. + testClientRequestContentKnownLengthGzipped(1024, false); + } + + @Test + public void testClientRequestLargeContentKnownLengthGzipped() throws Exception + { + // Lengths bigger than the buffer sizes will force chunked mode. + testClientRequestContentKnownLengthGzipped(1024 * 1024, true); + } + + private void testClientRequestContentKnownLengthGzipped(int length, final boolean expectChunked) throws Exception + { + byte[] bytes = new byte[length]; + new Random().nextBytes(bytes); + + prepareServer(new EchoHttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + String transferEncoding = request.getHeader(HttpHeader.TRANSFER_ENCODING.asString()); + if (expectChunked) + Assert.assertNotNull(transferEncoding); + else + Assert.assertNull(transferEncoding); + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + super.service(request, response); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new GZIPContentTransformer(ContentTransformer.IDENTITY); + } + }); + prepareClient(); + + byte[] gzipBytes = gzip(bytes); + ContentProvider gzipContent = new BytesContentProvider(gzipBytes); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .header(HttpHeader.CONTENT_ENCODING, "gzip") + .content(gzipContent) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertArrayEquals(bytes, response.getContent()); + } + + @Test + public void testServerResponseContentKnownLengthGzipped() throws Exception + { + byte[] bytes = new byte[1024]; + new Random().nextBytes(bytes); + final byte[] gzipBytes = gzip(bytes); + + prepareServer(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + response.getOutputStream().write(gzipBytes); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return new GZIPContentTransformer(ContentTransformer.IDENTITY); + } + }); + prepareClient(); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertArrayEquals(bytes, response.getContent()); + } + + @Test + public void testTransformUpstreamAndDownstreamKnownContentLengthGzipped() throws Exception + { + String data = "Google"; + byte[] bytes = data.getBytes(StandardCharsets.UTF_8); + + prepareServer(new EchoHttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + super.service(request, response); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new GZIPContentTransformer(new HrefTransformer.Client()); + } + + @Override + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return new GZIPContentTransformer(new HrefTransformer.Server()); + } + }); + prepareClient(); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .header(HttpHeader.CONTENT_ENCODING, "gzip") + .content(new BytesContentProvider(gzip(bytes))) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertArrayEquals(bytes, response.getContent()); + } + + @Test + public void testManySequentialTransformations() throws Exception + { + for (int i = 0; i < 8; ++i) + testTransformUpstreamAndDownstreamKnownContentLengthGzipped(); + } + + @Test + public void testUpstreamTransformationBufferedGzipped() throws Exception + { + prepareServer(new EchoHttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + super.service(request, response); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new GZIPContentTransformer(new BufferingContentTransformer()); + } + }); + prepareClient(); + + DeferredContentProvider content = new DeferredContentProvider(); + Request request = client.newRequest("localhost", serverConnector.getLocalPort()); + FutureResponseListener listener = new FutureResponseListener(request); + request.header(HttpHeader.CONTENT_ENCODING, "gzip") + .content(content) + .send(listener); + byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8); + content.offer(ByteBuffer.wrap(gzip(bytes))); + sleep(1000); + content.close(); + + ContentResponse response = listener.get(5, TimeUnit.SECONDS); + Assert.assertEquals(200, response.getStatus()); + Assert.assertArrayEquals(bytes, response.getContent()); + } + + @Test + public void testDownstreamTransformationBufferedGzipped() throws Exception + { + prepareServer(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + + ServletInputStream input = request.getInputStream(); + ServletOutputStream output = response.getOutputStream(); + int read; + while ((read = input.read()) >= 0) + { + output.write(read); + output.flush(); + } + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return new GZIPContentTransformer(new BufferingContentTransformer()); + } + }); + prepareClient(); + + byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8); + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .header(HttpHeader.CONTENT_ENCODING, "gzip") + .content(new BytesContentProvider(gzip(bytes))) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertArrayEquals(bytes, response.getContent()); + } + + @Test + public void testDiscardUpstreamAndDownstreamKnownContentLengthGzipped() throws Exception + { + final byte[] bytes = "ABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8); + prepareServer(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + Assert.assertEquals(-1, request.getInputStream().read()); + response.setHeader(HttpHeader.CONTENT_ENCODING.asString(), "gzip"); + response.getOutputStream().write(gzip(bytes)); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new GZIPContentTransformer(new DiscardContentTransformer()); + } + + @Override + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return new GZIPContentTransformer(new DiscardContentTransformer()); + } + }); + prepareClient(); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .header(HttpHeader.CONTENT_ENCODING, "gzip") + .content(new BytesContentProvider(gzip(bytes))) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(0, response.getContent().length); + } + + @Test + public void testUpstreamTransformationThrowsBeforeCommittingProxyRequest() throws Exception + { + prepareServer(new EchoHttpServlet()); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new ContentTransformer() + { + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + throw new NullPointerException("explicitly_thrown_by_test"); + } + }; + } + }); + prepareClient(); + + byte[] bytes = new byte[1024]; + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .content(new BytesContentProvider(bytes)) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(500, response.getStatus()); + } + + @Test + public void testUpstreamTransformationThrowsAfterCommittingProxyRequest() throws Exception + { + prepareServer(new EchoHttpServlet()); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newClientRequestContentTransformer(HttpServletRequest clientRequest, Request proxyRequest) + { + return new ContentTransformer() + { + private int count; + + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + if (++count < 2) + output.add(input); + else + throw new NullPointerException("explicitly_thrown_by_test"); + } + }; + } + }); + prepareClient(); + + final CountDownLatch latch = new CountDownLatch(1); + DeferredContentProvider content = new DeferredContentProvider(); + client.newRequest("localhost", serverConnector.getLocalPort()) + .content(content) + .send(new Response.CompleteListener() + { + @Override + public void onComplete(Result result) + { + if (result.isSucceeded() && result.getResponse().getStatus() == 502) + latch.countDown(); + } + }); + + content.offer(ByteBuffer.allocate(512)); + sleep(1000); + content.offer(ByteBuffer.allocate(512)); + content.close(); + + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + @Test + public void testDownstreamTransformationThrowsAtOnContent() throws Exception + { + testDownstreamTransformationThrows(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + // To trigger the test failure we need that onContent() + // is called twice, so the second time the test throws. + ServletOutputStream output = response.getOutputStream(); + output.write(new byte[512]); + output.flush(); + output.write(new byte[512]); + output.flush(); + } + }); + } + + @Test + public void testDownstreamTransformationThrowsAtOnSuccess() throws Exception + { + testDownstreamTransformationThrows(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + // To trigger the test failure we need that onContent() + // is called only once, so the the test throws from onSuccess(). + ServletOutputStream output = response.getOutputStream(); + output.write(new byte[512]); + output.flush(); + } + }); + } + + private void testDownstreamTransformationThrows(HttpServlet serverServlet) throws Exception + { + prepareServer(serverServlet); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected ContentTransformer newServerResponseContentTransformer(HttpServletRequest clientRequest, HttpServletResponse proxyResponse, Response serverResponse) + { + return new ContentTransformer() + { + private int count; + + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + if (++count < 2) + output.add(input); + else + throw new NullPointerException("explicitly_thrown_by_test"); + } + }; + } + }); + prepareClient(); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(502, response.getStatus()); + } + + @Test + public void testClientRequestReadFailsOnFirstRead() throws Exception + { + prepareServer(new EchoHttpServlet()); + prepareProxy(new AsyncMiddleManServlet() + { + @Override + protected int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException + { + throw new IOException("explicitly_thrown_by_test"); + } + }); + prepareClient(); + + final CountDownLatch latch = new CountDownLatch(1); + DeferredContentProvider content = new DeferredContentProvider(); + client.newRequest("localhost", serverConnector.getLocalPort()) + .content(content) + .send(new Response.CompleteListener() + { + @Override + public void onComplete(Result result) + { + System.err.println(result); + if (result.getResponse().getStatus() == 500) + latch.countDown(); + } + }); + content.offer(ByteBuffer.allocate(512)); + sleep(1000); + content.offer(ByteBuffer.allocate(512)); + content.close(); + + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + @Test + public void testClientRequestReadFailsOnSecondRead() throws Exception + { + prepareServer(new EchoHttpServlet()); + prepareProxy(new AsyncMiddleManServlet() + { + private int count; + + @Override + protected int readClientRequestContent(ServletInputStream input, byte[] buffer) throws IOException + { + if (++count < 2) + return super.readClientRequestContent(input, buffer); + else + throw new IOException("explicitly_thrown_by_test"); + } + }); + prepareClient(); + + final CountDownLatch latch = new CountDownLatch(1); + DeferredContentProvider content = new DeferredContentProvider(); + client.newRequest("localhost", serverConnector.getLocalPort()) + .content(content) + .send(new Response.CompleteListener() + { + @Override + public void onComplete(Result result) + { + if (result.getResponse().getStatus() == 502) + latch.countDown(); + } + }); + content.offer(ByteBuffer.allocate(512)); + sleep(1000); + content.offer(ByteBuffer.allocate(512)); + content.close(); + + Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + @Test + public void testProxyResponseWriteFailsOnFirstWrite() throws Exception + { + testProxyResponseWriteFails(1); + } + + @Test + public void testProxyResponseWriteFailsOnSecondWrite() throws Exception + { + testProxyResponseWriteFails(2); + } + + private void testProxyResponseWriteFails(final int writeCount) throws Exception + { + prepareServer(new HttpServlet() + { + @Override + protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException + { + ServletOutputStream output = response.getOutputStream(); + output.write(new byte[512]); + output.flush(); + output.write(new byte[512]); + } + }); + prepareProxy(new AsyncMiddleManServlet() + { + private int count; + + @Override + protected void writeProxyResponseContent(ServletOutputStream output, ByteBuffer content) throws IOException + { + if (++count < writeCount) + super.writeProxyResponseContent(output, content); + else + throw new IOException("explicitly_thrown_by_test"); + } + }); + prepareClient(); + + ContentResponse response = client.newRequest("localhost", serverConnector.getLocalPort()) + .timeout(5, TimeUnit.SECONDS) + .send(); + + Assert.assertEquals(502, response.getStatus()); + } + + private void sleep(long delay) throws IOException + { + try + { + Thread.sleep(delay); + } + catch (InterruptedException x) + { + throw new InterruptedIOException(); + } + } + + private byte[] gzip(byte[] bytes) throws IOException + { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (GZIPOutputStream gzipOut = new GZIPOutputStream(out)) + { + gzipOut.write(bytes); + } + return out.toByteArray(); + } + + private static abstract class HrefTransformer implements AsyncMiddleManServlet.ContentTransformer + { + private static final String PREFIX = "http://localhost/q="; + private final HrefParser parser = new HrefParser(); + private final List matches = new ArrayList<>(); + private boolean matching; + + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + int position = input.position(); + while (input.hasRemaining()) + { + boolean match = parser.parse(input); + + // Get the slice of what has been parsed so far. + int limit = input.limit(); + input.limit(input.position()); + input.position(position); + ByteBuffer slice = input.slice(); + input.position(input.limit()); + input.limit(limit); + position = input.position(); + + if (matching) + { + if (match) + { + ByteBuffer copy = ByteBuffer.allocate(slice.remaining()); + copy.put(slice).flip(); + matches.add(copy); + } + else + { + matching = false; + + // Transform the matches. + Utf8StringBuilder builder = new Utf8StringBuilder(); + for (ByteBuffer buffer : matches) + builder.append(buffer); + String transformed = transform(builder.toString()); + output.add(ByteBuffer.wrap(transformed.getBytes(StandardCharsets.UTF_8))); + output.add(slice); + } + } + else + { + if (match) + { + matching = true; + ByteBuffer copy = ByteBuffer.allocate(slice.remaining()); + copy.put(slice).flip(); + matches.add(copy); + } + else + { + output.add(slice); + } + } + } + } + + protected abstract String transform(String value) throws IOException; + + private static class Client extends HrefTransformer + { + @Override + protected String transform(String value) throws IOException + { + String result = PREFIX + URLEncoder.encode(value, "UTF-8"); + LOG.debug("{} -> {}", value, result); + return result; + } + } + + private static class Server extends HrefTransformer + { + @Override + protected String transform(String value) throws IOException + { + String result = URLDecoder.decode(value.substring(PREFIX.length()), "UTF-8"); + LOG.debug("{} <- {}", value, result); + return result; + } + } + } + + private static class HrefParser + { + private final byte[] token = {'h', 'r', 'e', 'f', '=', '"'}; + private int state; + + private boolean parse(ByteBuffer buffer) + { + while (buffer.hasRemaining()) + { + int current = buffer.get() & 0xFF; + if (state < token.length) + { + if (Character.toLowerCase(current) != token[state]) + { + state = 0; + continue; + } + + ++state; + if (state == token.length) + return false; + } + else + { + // Look for the ending quote. + if (current == '"') + { + buffer.position(buffer.position() - 1); + state = 0; + return true; + } + } + } + return state == token.length; + } + } + + private static class BufferingContentTransformer implements AsyncMiddleManServlet.ContentTransformer + { + private final List buffers = new ArrayList<>(); + + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + if (input.hasRemaining()) + { + ByteBuffer copy = ByteBuffer.allocate(input.remaining()); + copy.put(input).flip(); + buffers.add(copy); + } + + if (finished) + { + Assert.assertFalse(buffers.isEmpty()); + output.addAll(buffers); + buffers.clear(); + } + } + } + + private static class DiscardContentTransformer implements AsyncMiddleManServlet.ContentTransformer + { + @Override + public void transform(ByteBuffer input, boolean finished, List output) throws IOException + { + } + } +} diff --git a/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ProxyServletTest.java b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ProxyServletTest.java index 6d154755cc4..da4548f81af 100644 --- a/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ProxyServletTest.java +++ b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ProxyServletTest.java @@ -104,7 +104,8 @@ public class ProxyServletTest { return Arrays.asList(new Object[][]{ {ProxyServlet.class}, - {AsyncProxyServlet.class} + {AsyncProxyServlet.class}, + {AsyncMiddleManServlet.class} }); } @@ -114,13 +115,13 @@ public class ProxyServletTest private Server proxy; private ServerConnector proxyConnector; private ServletContextHandler proxyContext; - private ProxyServlet proxyServlet; + private AbstractProxyServlet proxyServlet; private Server server; private ServerConnector serverConnector; public ProxyServletTest(Class proxyServletClass) throws Exception { - this.proxyServlet = (ProxyServlet)proxyServletClass.newInstance(); + this.proxyServlet = (AbstractProxyServlet)proxyServletClass.newInstance(); } private void prepareProxy() throws Exception @@ -763,12 +764,12 @@ public class ProxyServletTest } @Override - protected void onResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) + protected void onProxyResponseSuccess(HttpServletRequest request, HttpServletResponse response, Response proxyResponse) { byte[] content = temp.remove(request.getRequestURI()).toByteArray(); ContentResponse cached = new HttpContentResponse(proxyResponse, content, null, null); cache.put(request.getRequestURI(), cached); - super.onResponseSuccess(request, response, proxyResponse); + super.onProxyResponseSuccess(request, response, proxyResponse); } }; prepareProxy();