Extract common http logic to server (#31311)

This is related to #28898. With the addition of the http nio transport,
we now have two different modules that provide http transports.
Currently most of the http logic lives at the module level. However,
some of this logic can live in server. In particular, some of the
setting of headers, cors, and pipelining. This commit begins this moving
in that direction by introducing lower level abstraction (HttpChannel,
HttpRequest, and HttpResonse) that is implemented by the modules. The
higher level rest request and rest channel work can live entirely in
server.
This commit is contained in:
Tim Brooks 2018-06-14 15:10:02 -06:00 committed by GitHub
parent 6dd81ead74
commit fcf1e41e42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
51 changed files with 2111 additions and 2277 deletions

View File

@ -19,252 +19,58 @@
package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.transport.netty4.Netty4Utils;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.net.InetSocketAddress;
final class Netty4HttpChannel extends AbstractRestChannel {
public class Netty4HttpChannel implements HttpChannel {
private final Netty4HttpServerTransport transport;
private final Channel channel;
private final FullHttpRequest nettyRequest;
private final int sequence;
private final ThreadContext threadContext;
private final HttpHandlingSettings handlingSettings;
/**
* @param transport The corresponding <code>NettyHttpServerTransport</code> where this channel belongs to.
* @param request The request that is handled by this channel.
* @param sequence The pipelining sequence number for this request
* @param handlingSettings true if error messages should include stack traces.
* @param threadContext the thread context for the channel
*/
Netty4HttpChannel(Netty4HttpServerTransport transport, Netty4HttpRequest request, int sequence, HttpHandlingSettings handlingSettings,
ThreadContext threadContext) {
super(request, handlingSettings.getDetailedErrorsEnabled());
this.transport = transport;
this.channel = request.getChannel();
this.nettyRequest = request.request();
this.sequence = sequence;
this.threadContext = threadContext;
this.handlingSettings = handlingSettings;
Netty4HttpChannel(Channel channel) {
this.channel = channel;
}
@Override
protected BytesStreamOutput newBytesOutput() {
return new ReleasableBytesStreamOutput(transport.bigArrays);
public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
ChannelPromise writePromise = channel.newPromise();
writePromise.addListener(f -> {
if (f.isSuccess()) {
listener.onResponse(null);
} else {
final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause);
if (cause instanceof Error) {
listener.onFailure(new Exception(cause));
} else {
listener.onFailure((Exception) cause);
}
}
});
channel.writeAndFlush(response, writePromise);
}
@Override
public void sendResponse(RestResponse response) {
// if the response object was created upstream, then use it;
// otherwise, create a new one
ByteBuf buffer = Netty4Utils.toByteBuf(response.content());
final FullHttpResponse resp;
if (HttpMethod.HEAD.equals(nettyRequest.method())) {
resp = newResponse(Unpooled.EMPTY_BUFFER);
} else {
resp = newResponse(buffer);
}
resp.setStatus(getStatus(response.status()));
Netty4CorsHandler.setCorsResponseHeaders(nettyRequest, resp, transport.getCorsConfig());
String opaque = nettyRequest.headers().get("X-Opaque-Id");
if (opaque != null) {
setHeaderField(resp, "X-Opaque-Id", opaque);
}
// Add all custom headers
addCustomHeaders(resp, response.getHeaders());
addCustomHeaders(resp, threadContext.getResponseHeaders());
BytesReference content = response.content();
boolean releaseContent = content instanceof Releasable;
boolean releaseBytesStreamOutput = bytesOutputOrNull() instanceof ReleasableBytesStreamOutput;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false);
addCookies(resp);
final ChannelPromise promise = channel.newPromise();
if (releaseContent) {
promise.addListener(f -> ((Releasable) content).close());
}
if (releaseBytesStreamOutput) {
promise.addListener(f -> bytesOutputOrNull().close());
}
if (isCloseConnection()) {
promise.addListener(ChannelFutureListener.CLOSE);
}
Netty4HttpResponse newResponse = new Netty4HttpResponse(sequence, resp);
channel.writeAndFlush(newResponse, promise);
releaseContent = false;
releaseBytesStreamOutput = false;
} finally {
if (releaseContent) {
((Releasable) content).close();
}
if (releaseBytesStreamOutput) {
bytesOutputOrNull().close();
}
}
public InetSocketAddress getLocalAddress() {
return (InetSocketAddress) channel.localAddress();
}
private void setHeaderField(HttpResponse resp, String headerField, String value) {
setHeaderField(resp, headerField, value, true);
@Override
public InetSocketAddress getRemoteAddress() {
return (InetSocketAddress) channel.remoteAddress();
}
private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) {
if (override || !resp.headers().contains(headerField)) {
resp.headers().add(headerField, value);
}
@Override
public void close() {
channel.close();
}
private void addCookies(HttpResponse resp) {
if (handlingSettings.isResetCookies()) {
String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<io.netty.handler.codec.http.cookie.Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
// Reset the cookies if necessary.
resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies));
}
}
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0);
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) ||
(http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)));
}
// Create a new {@link HttpResponse} to transmit the response for the netty request.
private FullHttpResponse newResponse(ByteBuf buffer) {
final boolean http10 = isHttp10();
final boolean close = isCloseConnection();
// Build the response object.
final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize
final FullHttpResponse response;
if (http10) {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer);
if (!close) {
response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive");
}
} else {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer);
}
return response;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
public Channel getNettyChannel() {
return channel;
}
}

View File

@ -66,7 +66,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
try {
List<Tuple<Netty4HttpResponse, ChannelPromise>> readyResponses = aggregator.write(response, promise);
for (Tuple<Netty4HttpResponse, ChannelPromise> readyResponse : readyResponses) {
ctx.write(readyResponse.v1().getResponse(), readyResponse.v2());
ctx.write(readyResponse.v1(), readyResponse.v2());
}
success = true;
} catch (IllegalStateException e) {

View File

@ -19,17 +19,22 @@
package org.elasticsearch.http.netty4;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;
import java.net.SocketAddress;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Collections;
@ -38,25 +43,16 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class Netty4HttpRequest extends RestRequest {
public class Netty4HttpRequest implements HttpRequest {
private final FullHttpRequest request;
private final Channel channel;
private final BytesReference content;
private final HttpHeadersMap headers;
private final int sequence;
/**
* Construct a new request.
*
* @param xContentRegistry the content registry
* @param request the underlying request
* @param channel the channel for the request
* @throws BadParameterException if the parameters can not be decoded
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) {
super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers()));
Netty4HttpRequest(FullHttpRequest request, int sequence) {
this.request = request;
this.channel = channel;
headers = new HttpHeadersMap(request.headers());
this.sequence = sequence;
if (request.content().isReadable()) {
this.content = Netty4Utils.toBytesReference(request.content());
} else {
@ -64,71 +60,39 @@ public class Netty4HttpRequest extends RestRequest {
}
}
/**
* Construct a new request. In contrast to
* {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so
* this constructor will not throw a {@link BadParameterException}.
*
* @param xContentRegistry the content registry
* @param params the parameters for the request
* @param uri the path for the request
* @param request the underlying request
* @param channel the channel for the request
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
Netty4HttpRequest(
final NamedXContentRegistry xContentRegistry,
final Map<String, String> params,
final String uri,
final FullHttpRequest request,
final Channel channel) {
super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers()));
this.request = request;
this.channel = channel;
if (request.content().isReadable()) {
this.content = Netty4Utils.toBytesReference(request.content());
} else {
this.content = BytesArray.EMPTY;
}
}
public FullHttpRequest request() {
return this.request;
}
@Override
public Method method() {
public RestRequest.Method method() {
HttpMethod httpMethod = request.method();
if (httpMethod == HttpMethod.GET)
return Method.GET;
return RestRequest.Method.GET;
if (httpMethod == HttpMethod.POST)
return Method.POST;
return RestRequest.Method.POST;
if (httpMethod == HttpMethod.PUT)
return Method.PUT;
return RestRequest.Method.PUT;
if (httpMethod == HttpMethod.DELETE)
return Method.DELETE;
return RestRequest.Method.DELETE;
if (httpMethod == HttpMethod.HEAD) {
return Method.HEAD;
return RestRequest.Method.HEAD;
}
if (httpMethod == HttpMethod.OPTIONS) {
return Method.OPTIONS;
return RestRequest.Method.OPTIONS;
}
if (httpMethod == HttpMethod.PATCH) {
return Method.PATCH;
return RestRequest.Method.PATCH;
}
if (httpMethod == HttpMethod.TRACE) {
return Method.TRACE;
return RestRequest.Method.TRACE;
}
if (httpMethod == HttpMethod.CONNECT) {
return Method.CONNECT;
return RestRequest.Method.CONNECT;
}
throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
@ -139,40 +103,64 @@ public class Netty4HttpRequest extends RestRequest {
return request.uri();
}
@Override
public boolean hasContent() {
return content.length() > 0;
}
@Override
public BytesReference content() {
return content;
}
/**
* Returns the remote address where this rest request channel is "connected to". The
* returned {@link SocketAddress} is supposed to be down-cast into more
* concrete type such as {@link java.net.InetSocketAddress} to retrieve
* the detailed information.
*/
@Override
public SocketAddress getRemoteAddress() {
return channel.remoteAddress();
public final Map<String, List<String>> getHeaders() {
return headers;
}
/**
* Returns the local address where this request channel is bound to. The returned
* {@link SocketAddress} is supposed to be down-cast into more concrete
* type such as {@link java.net.InetSocketAddress} to retrieve the detailed
* information.
*/
@Override
public SocketAddress getLocalAddress() {
return channel.localAddress();
public List<String> strictCookies() {
String cookieString = request.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
return ServerCookieEncoder.STRICT.encode(cookies);
}
}
return Collections.emptyList();
}
public Channel getChannel() {
return channel;
@Override
public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpRequest.HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
}
}
@Override
public HttpRequest removeHeader(String header) {
HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove(header);
HttpHeaders trailingHeaders = new DefaultHttpHeaders();
trailingHeaders.add(request.trailingHeaders());
trailingHeaders.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(),
request.content(), headersWithoutContentTypeHeader, trailingHeaders);
return new Netty4HttpRequest(requestWithoutHeader, sequence);
}
@Override
public Netty4HttpResponse createResponse(RestStatus status, BytesReference content) {
return new Netty4HttpResponse(this, status, content);
}
public FullHttpRequest nettyRequest() {
return request;
}
int sequence() {
return sequence;
}
/**
@ -249,7 +237,7 @@ public class Netty4HttpRequest extends RestRequest {
@Override
public Set<Entry<String, List<String>>> entrySet() {
return httpHeaders.names().stream().map(k -> new AbstractMap.SimpleImmutableEntry<>(k, httpHeaders.getAll(k)))
.collect(Collectors.toSet());
.collect(Collectors.toSet());
}
}
}

View File

@ -20,112 +20,51 @@
package org.elasticsearch.http.netty4;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.netty4.Netty4Utils;
import java.util.Collections;
@ChannelHandler.Sharable
class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
private final Netty4HttpServerTransport serverTransport;
private final HttpHandlingSettings handlingSettings;
private final ThreadContext threadContext;
Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings,
ThreadContext threadContext) {
Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport) {
this.serverTransport = serverTransport;
this.handlingSettings = handlingSettings;
this.threadContext = threadContext;
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> msg) throws Exception {
final FullHttpRequest request = msg.getRequest();
Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
FullHttpRequest request = msg.getRequest();
try {
final FullHttpRequest copiedRequest =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
Unpooled.copiedBuffer(request.content()),
request.headers(),
request.trailingHeaders());
final FullHttpRequest copy =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
Unpooled.copiedBuffer(request.content()),
request.headers(),
request.trailingHeaders());
Exception badRequestCause = null;
/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final Netty4HttpRequest httpRequest;
{
Netty4HttpRequest innerHttpRequest;
try {
innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel());
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutParameters(copy, ctx.channel());
}
httpRequest = innerHttpRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these
* parameter values.
*/
final Netty4HttpChannel channel;
{
Netty4HttpChannel innerChannel;
try {
innerChannel =
new Netty4HttpChannel(serverTransport, httpRequest, msg.getSequence(), handlingSettings, threadContext);
} catch (final IllegalArgumentException e) {
if (badRequestCause == null) {
badRequestCause = e;
} else {
badRequestCause.addSuppressed(e);
}
final Netty4HttpRequest innerRequest =
new Netty4HttpRequest(
serverTransport.xContentRegistry,
Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters
copy.uri(),
copy,
ctx.channel());
innerChannel =
new Netty4HttpChannel(serverTransport, innerRequest, msg.getSequence(), handlingSettings, threadContext);
}
channel = innerChannel;
}
Netty4HttpRequest httpRequest = new Netty4HttpRequest(copiedRequest, msg.getSequence());
if (request.decoderResult().isFailure()) {
serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause());
} else if (badRequestCause != null) {
serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause);
Throwable cause = request.decoderResult().cause();
if (cause instanceof Error) {
ExceptionsHelper.dieOnError(cause);
serverTransport.incomingRequestError(httpRequest, channel, new Exception(cause));
} else {
serverTransport.incomingRequestError(httpRequest, channel, (Exception) cause);
}
} else {
serverTransport.dispatchRequest(httpRequest, channel);
serverTransport.incomingRequest(httpRequest, channel);
}
} finally {
// As we have copied the buffer, we can release the request
@ -133,32 +72,6 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined
}
}
private Netty4HttpRequest requestWithoutContentTypeHeader(
final FullHttpRequest request, final Channel channel, final Exception badRequestCause) {
final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove("Content-Type");
final FullHttpRequest requestWithoutContentTypeHeader =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
request.content(),
headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again
request.trailingHeaders()); // Content-Type can not be a trailing header
try {
return new Netty4HttpRequest(serverTransport.xContentRegistry, requestWithoutContentTypeHeader, channel);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return requestWithoutParameters(requestWithoutContentTypeHeader, channel);
}
}
private Netty4HttpRequest requestWithoutParameters(final FullHttpRequest request, final Channel channel) {
// remove all parameters as at least one is incorrectly encoded
return new Netty4HttpRequest(serverTransport.xContentRegistry, Collections.emptyMap(), request.uri(), request, channel);
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
Netty4Utils.maybeDie(cause);

View File

@ -19,19 +19,103 @@
package org.elasticsearch.http.netty4;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedMessage;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;
public class Netty4HttpResponse extends HttpPipelinedMessage {
import java.util.Collections;
import java.util.EnumMap;
import java.util.Map;
private final FullHttpResponse response;
public class Netty4HttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage {
public Netty4HttpResponse(int sequence, FullHttpResponse response) {
super(sequence);
this.response = response;
private final int sequence;
private final Netty4HttpRequest request;
Netty4HttpResponse(Netty4HttpRequest request, RestStatus status, BytesReference content) {
super(request.nettyRequest().protocolVersion(), getStatus(status), Netty4Utils.toByteBuf(content));
this.sequence = request.sequence();
this.request = request;
}
public FullHttpResponse getResponse() {
return response;
@Override
public void addHeader(String name, String value) {
headers().add(name, value);
}
@Override
public boolean containsHeader(String name) {
return headers().contains(name);
}
@Override
public int getSequence() {
return sequence;
}
public Netty4HttpRequest getRequest() {
return request;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
}
}

View File

@ -39,6 +39,7 @@ import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.AttributeKey;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.common.Strings;
@ -53,9 +54,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.http.AbstractHttpServerTransport;
import org.elasticsearch.http.BindHttpException;
import org.elasticsearch.http.HttpHandlingSettings;
@ -149,38 +148,29 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
public static final Setting<ByteSizeValue> SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE =
Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope);
protected final BigArrays bigArrays;
private final ByteSizeValue maxInitialLineLength;
private final ByteSizeValue maxHeaderSize;
private final ByteSizeValue maxChunkSize;
protected final ByteSizeValue maxInitialLineLength;
protected final ByteSizeValue maxHeaderSize;
protected final ByteSizeValue maxChunkSize;
private final int workerCount;
protected final int workerCount;
private final int pipeliningMaxEvents;
protected final int pipeliningMaxEvents;
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean reuseAddress;
/**
* The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}.
*/
protected final NamedXContentRegistry xContentRegistry;
protected final boolean tcpNoDelay;
protected final boolean tcpKeepAlive;
protected final boolean reuseAddress;
protected final ByteSizeValue tcpSendBufferSize;
protected final ByteSizeValue tcpReceiveBufferSize;
protected final RecvByteBufAllocator recvByteBufAllocator;
private final ByteSizeValue tcpSendBufferSize;
private final ByteSizeValue tcpReceiveBufferSize;
private final RecvByteBufAllocator recvByteBufAllocator;
private final int readTimeoutMillis;
protected final int maxCompositeBufferComponents;
private final int maxCompositeBufferComponents;
protected volatile ServerBootstrap serverBootstrap;
protected final List<Channel> serverChannels = new ArrayList<>();
protected final HttpHandlingSettings httpHandlingSettings;
// package private for testing
Netty4OpenChannelsHandler serverOpenChannels;
@ -189,16 +179,13 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) {
super(settings, networkService, threadPool, dispatcher);
super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings));
this.bigArrays = bigArrays;
this.xContentRegistry = xContentRegistry;
this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings);
this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings);
this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings);
@ -398,26 +385,27 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
}
public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext());
return new HttpChannelHandler(this, handlingSettings);
}
static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
private final Netty4HttpServerTransport transport;
private final Netty4HttpRequestHandler requestHandler;
private final HttpHandlingSettings handlingSettings;
protected HttpChannelHandler(
final Netty4HttpServerTransport transport,
final HttpHandlingSettings handlingSettings,
final ThreadContext threadContext) {
protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) {
this.transport = transport;
this.handlingSettings = handlingSettings;
this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext);
this.requestHandler = new Netty4HttpRequestHandler(transport);
}
@Override
protected void initChannel(Channel ch) throws Exception {
Netty4HttpChannel nettyTcpChannel = new Netty4HttpChannel(ch);
ch.attr(HTTP_CHANNEL_KEY).set(nettyTcpChannel);
ch.pipeline().addLast("openChannels", transport.serverOpenChannels);
ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS));
final HttpRequestDecoder decoder = new HttpRequestDecoder(

View File

@ -22,6 +22,7 @@ package org.elasticsearch.http.netty4.cors;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings;
import org.elasticsearch.http.netty4.Netty4HttpResponse;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@ -76,6 +78,14 @@ public class Netty4CorsHandler extends ChannelDuplexHandler {
ctx.fireChannelRead(msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass();
Netty4HttpResponse response = (Netty4HttpResponse) msg;
setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config);
ctx.write(response, promise);;
}
public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) {
if (!config.isCorsSupportEnabled()) {
return;

View File

@ -333,10 +333,10 @@ public class Netty4Transport extends TcpTransport {
addClosedExceptionLogger(ch);
NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name);
ch.attr(CHANNEL_KEY).set(nettyTcpChannel);
serverAcceptedChannel(nettyTcpChannel);
ch.pipeline().addLast("logging", new ESLoggingHandler());
ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder());
ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name));
serverAcceptedChannel(nettyTcpChannel);
}
@Override

View File

@ -98,8 +98,11 @@ public class NettyTcpChannel implements TcpChannel {
} else {
final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause);
assert cause instanceof Exception;
listener.onFailure((Exception) cause);
if (cause instanceof Error) {
listener.onFailure(new Exception(cause));
} else {
listener.onFailure((Exception) cause);
}
}
});
channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise);

View File

@ -0,0 +1,148 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class Netty4CorsTests extends ESTestCase {
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = Netty4CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
EmbeddedChannel embeddedChannel = new EmbeddedChannel();
embeddedChannel.pipeline().addLast(new Netty4CorsHandler(Netty4HttpServerTransport.buildCorsConfig(settings)));
Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest, 0);
embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content")));
return embeddedChannel.readOutbound();
}
}

View File

@ -1,616 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelProgressivePromise;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasablePagedBytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ByteArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class Netty4HttpChannelTests extends ESTestCase {
private NetworkService networkService;
private ThreadPool threadPool;
private MockBigArrays bigArrays;
@Before
public void setup() throws Exception {
networkService = new NetworkService(Collections.emptyList());
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(Netty4Utils.toByteBuf(new TestResponse().content())));
}
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = Netty4CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
public void testHeadersSet() {
Settings settings = Settings.builder().build();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(),
new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote");
final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
// send a response
Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
TestResponse resp = new TestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
assertThat(writtenObjects.size(), is(1));
HttpResponse response = ((Netty4HttpResponse) writtenObjects.get(0)).getResponse();
assertThat(response.headers().get("non-existent-header"), nullValue());
assertThat(response.headers().get(customHeader), equalTo(customHeaderValue));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length())));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType()));
}
}
public void testReleaseOnSendToClosedChannel() {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) {
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final TestResponse response = new TestResponse(bigArrays);
assertThat(response.content(), instanceOf(Releasable.class));
embeddedChannel.close();
channel.sendResponse(response);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
}
public void testReleaseOnSendToChannelAfterException() throws IOException {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) {
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
}
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
}
} else {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/");
if (!close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
}
}
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, embeddedChannel);
// send a response, the channel close status should match
assertTrue(embeddedChannel.isOpen());
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final TestResponse resp = new TestResponse();
channel.sendResponse(resp);
assertThat(embeddedChannel.isOpen(), equalTo(!close));
}
}
private FullHttpResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(),
new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
final Netty4HttpRequest request =
new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
channel.sendResponse(new TestResponse());
// get the response
List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
assertThat(writtenObjects.size(), is(1));
return ((Netty4HttpResponse) writtenObjects.get(0)).getResponse();
}
}
private static class WriteCapturingChannel implements Channel {
private List<Object> writtenObjects = new ArrayList<>();
@Override
public ChannelId id() {
return null;
}
@Override
public EventLoop eventLoop() {
return null;
}
@Override
public Channel parent() {
return null;
}
@Override
public ChannelConfig config() {
return null;
}
@Override
public boolean isOpen() {
return false;
}
@Override
public boolean isRegistered() {
return false;
}
@Override
public boolean isActive() {
return false;
}
@Override
public ChannelMetadata metadata() {
return null;
}
@Override
public SocketAddress localAddress() {
return null;
}
@Override
public SocketAddress remoteAddress() {
return null;
}
@Override
public ChannelFuture closeFuture() {
return null;
}
@Override
public boolean isWritable() {
return false;
}
@Override
public long bytesBeforeUnwritable() {
return 0;
}
@Override
public long bytesBeforeWritable() {
return 0;
}
@Override
public Unsafe unsafe() {
return null;
}
@Override
public ChannelPipeline pipeline() {
return null;
}
@Override
public ByteBufAllocator alloc() {
return null;
}
@Override
public Channel read() {
return null;
}
@Override
public Channel flush() {
return null;
}
@Override
public ChannelFuture bind(SocketAddress localAddress) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) {
return null;
}
@Override
public ChannelFuture disconnect() {
return null;
}
@Override
public ChannelFuture close() {
return null;
}
@Override
public ChannelFuture deregister() {
return null;
}
@Override
public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture disconnect(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture close(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture deregister(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture write(Object msg) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture write(Object msg, ChannelPromise promise) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture writeAndFlush(Object msg) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelPromise newPromise() {
return null;
}
@Override
public ChannelProgressivePromise newProgressivePromise() {
return null;
}
@Override
public ChannelFuture newSucceededFuture() {
return null;
}
@Override
public ChannelFuture newFailedFuture(Throwable cause) {
return null;
}
@Override
public ChannelPromise voidPromise() {
return null;
}
@Override
public <T> Attribute<T> attr(AttributeKey<T> key) {
return null;
}
@Override
public <T> boolean hasAttr(AttributeKey<T> key) {
return false;
}
@Override
public int compareTo(Channel o) {
return 0;
}
List<Object> getWrittenObjects() {
return writtenObjects;
}
}
private static class TestResponse extends RestResponse {
private final BytesReference reference;
TestResponse() {
reference = Netty4Utils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8));
}
TestResponse(final BigArrays bigArrays) {
final byte[] bytes;
try {
bytes = "content".getBytes("UTF-8");
} catch (final UnsupportedEncodingException e) {
throw new AssertionError(e);
}
final ByteArray bigArray = bigArrays.newByteArray(bytes.length);
bigArray.set(0, bytes, 0, bytes.length);
reference = new ReleasablePagedBytesReference(bigArrays, bigArray, bytes.length, Releasables.releaseOnce(bigArray));
}
@Override
public String contentType() {
return "text";
}
@Override
public BytesReference content() {
return reference;
}
@Override
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -19,15 +19,12 @@
package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
@ -55,7 +55,6 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static org.hamcrest.core.Is.is;
@ -191,11 +190,11 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
ArrayList<ChannelPromise> promises = new ArrayList<>();
for (int i = 1; i < requests.size(); ++i) {
final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK);
ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise);
int sequence = requests.get(i).getSequence();
Netty4HttpResponse resp = new Netty4HttpResponse(sequence, httpResponse);
HttpPipelinedRequest<FullHttpRequest> pipelinedRequest = requests.get(i);
Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
Netty4HttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.writeAndFlush(resp, promise);
}
@ -233,10 +232,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
}
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<LastHttpContent>> {
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
@Override
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<LastHttpContent> pipelinedRequest) {
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> pipelinedRequest) {
LastHttpContent request = pipelinedRequest.getRequest();
final QueryStringDecoder decoder;
if (request instanceof FullHttpRequest) {
@ -246,9 +245,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
}
final String uri = decoder.path().replace("/", "");
final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8);
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content);
httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes());
final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8));
Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
Netty4HttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content);
httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length()));
final CountDownLatch waitingLatch = new CountDownLatch(1);
waitingRequests.put(uri, waitingLatch);
@ -260,7 +260,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
waitingLatch.await(1000, TimeUnit.SECONDS);
final ChannelPromise promise = ctx.newPromise();
eventLoopService.submit(() -> {
ctx.write(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise);
ctx.write(httpResponse, promise);
finishingLatch.countDown();
});
} catch (InterruptedException e) {

View File

@ -26,22 +26,20 @@ import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
@ -120,7 +118,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
@Override
public ChannelHandler configureServerChannelHandler() {
return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext());
return new CustomHttpChannelHandler(this, executorService);
}
@Override
@ -135,8 +133,8 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
private final ExecutorService executorService;
CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) {
super(transport, transport.httpHandlingSettings, threadContext);
CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService) {
super(transport, transport.handlingSettings);
this.executorService = executorService;
}
@ -187,8 +185,9 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8);
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer);
httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
Netty4HttpRequest httpRequest = new Netty4HttpRequest(fullHttpRequest, pipelinedRequest.getSequence());
Netty4HttpResponse response = httpRequest.createResponse(RestStatus.OK, new BytesArray(uri.getBytes(StandardCharsets.UTF_8)));
response.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
final boolean slow = uri.matches("/slow/\\d+");
if (slow) {
@ -202,7 +201,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
}
final ChannelPromise promise = ctx.newPromise();
ctx.writeAndFlush(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise);
ctx.writeAndFlush(response, promise);
}
}

View File

@ -291,40 +291,6 @@ public class Netty4HttpServerTransportTests extends ESTestCase {
assertThat(causeReference.get(), instanceOf(TooLongFrameException.class));
}
public void testDispatchDoesNotModifyThreadContext() throws InterruptedException {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (Netty4HttpServerTransport transport =
new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
transport.start();
transport.dispatchRequest(null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchBadRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
public void testReadTimeout() throws Exception {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {

View File

@ -23,54 +23,38 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.rest.RestRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;
public class HttpReadWriteHandler implements ReadWriteHandler {
private final NettyAdaptor adaptor;
private final NioSocketChannel nioChannel;
private final NioHttpChannel nioHttpChannel;
private final NioHttpServerTransport transport;
private final HttpHandlingSettings settings;
private final NamedXContentRegistry xContentRegistry;
private final NioCorsConfig corsConfig;
private final ThreadContext threadContext;
HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
NamedXContentRegistry xContentRegistry, NioCorsConfig corsConfig, ThreadContext threadContext) {
this.nioChannel = nioChannel;
HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
NioCorsConfig corsConfig) {
this.nioHttpChannel = nioHttpChannel;
this.transport = transport;
this.settings = settings;
this.xContentRegistry = xContentRegistry;
this.corsConfig = corsConfig;
this.threadContext = threadContext;
List<ChannelHandler> handlers = new ArrayList<>(5);
HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(),
@ -89,7 +73,7 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents()));
adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0]));
adaptor.addCloseListener((v, e) -> nioChannel.close());
adaptor.addCloseListener((v, e) -> nioHttpChannel.close());
}
@Override
@ -150,95 +134,22 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
request.headers(),
request.trailingHeaders());
Exception badRequestCause = null;
/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final NioHttpRequest httpRequest;
{
NioHttpRequest innerHttpRequest;
try {
innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutParameters(copiedRequest);
}
httpRequest = innerHttpRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of
* these parameter values.
*/
final NioHttpChannel channel;
{
NioHttpChannel innerChannel;
int sequence = pipelinedRequest.getSequence();
BigArrays bigArrays = transport.getBigArrays();
try {
innerChannel = new NioHttpChannel(nioChannel, bigArrays, httpRequest, sequence, settings, corsConfig, threadContext);
} catch (final IllegalArgumentException e) {
if (badRequestCause == null) {
badRequestCause = e;
} else {
badRequestCause.addSuppressed(e);
}
final NioHttpRequest innerRequest =
new NioHttpRequest(
xContentRegistry,
Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters
copiedRequest.uri(),
copiedRequest);
innerChannel = new NioHttpChannel(nioChannel, bigArrays, innerRequest, sequence, settings, corsConfig, threadContext);
}
channel = innerChannel;
}
NioHttpRequest httpRequest = new NioHttpRequest(copiedRequest, pipelinedRequest.getSequence());
if (request.decoderResult().isFailure()) {
transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause());
} else if (badRequestCause != null) {
transport.dispatchBadRequest(httpRequest, channel, badRequestCause);
Throwable cause = request.decoderResult().cause();
if (cause instanceof Error) {
ExceptionsHelper.dieOnError(cause);
transport.incomingRequestError(httpRequest, nioHttpChannel, new Exception(cause));
} else {
transport.incomingRequestError(httpRequest, nioHttpChannel, (Exception) cause);
}
} else {
transport.dispatchRequest(httpRequest, channel);
transport.incomingRequest(httpRequest, nioHttpChannel);
}
} finally {
// As we have copied the buffer, we can release the request
request.release();
}
}
private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) {
final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove("Content-Type");
final FullHttpRequest requestWithoutContentTypeHeader =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
request.content(),
headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again
request.trailingHeaders()); // Content-Type can not be a trailing header
try {
return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return requestWithoutParameters(requestWithoutContentTypeHeader);
}
}
private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) {
// remove all parameters as at least one is incorrectly encoded
return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request);
}
}

View File

@ -19,244 +19,21 @@
package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.io.IOException;
import java.nio.channels.SocketChannel;
public class NioHttpChannel extends AbstractRestChannel {
public class NioHttpChannel extends NioSocketChannel implements HttpChannel {
private final BigArrays bigArrays;
private final int sequence;
private final NioCorsConfig corsConfig;
private final ThreadContext threadContext;
private final FullHttpRequest nettyRequest;
private final NioSocketChannel nioChannel;
private final boolean resetCookies;
NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, int sequence,
HttpHandlingSettings settings, NioCorsConfig corsConfig, ThreadContext threadContext) {
super(request, settings.getDetailedErrorsEnabled());
this.nioChannel = nioChannel;
this.bigArrays = bigArrays;
this.sequence = sequence;
this.corsConfig = corsConfig;
this.threadContext = threadContext;
this.nettyRequest = request.getRequest();
this.resetCookies = settings.isResetCookies();
NioHttpChannel(SocketChannel socketChannel) throws IOException {
super(socketChannel);
}
@Override
public void sendResponse(RestResponse response) {
// if the response object was created upstream, then use it;
// otherwise, create a new one
ByteBuf buffer = ByteBufUtils.toByteBuf(response.content());
final FullHttpResponse resp;
if (HttpMethod.HEAD.equals(nettyRequest.method())) {
resp = newResponse(Unpooled.EMPTY_BUFFER);
} else {
resp = newResponse(buffer);
}
resp.setStatus(getStatus(response.status()));
NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
String opaque = nettyRequest.headers().get("X-Opaque-Id");
if (opaque != null) {
setHeaderField(resp, "X-Opaque-Id", opaque);
}
// Add all custom headers
addCustomHeaders(resp, response.getHeaders());
addCustomHeaders(resp, threadContext.getResponseHeaders());
ArrayList<Releasable> toClose = new ArrayList<>(3);
boolean success = false;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false);
addCookies(resp);
BytesReference content = response.content();
if (content instanceof Releasable) {
toClose.add((Releasable) content);
}
BytesStreamOutput bytesStreamOutput = bytesOutputOrNull();
if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) {
toClose.add((Releasable) bytesStreamOutput);
}
if (isCloseConnection()) {
toClose.add(nioChannel::close);
}
BiConsumer<Void, Exception> listener = (aVoid, ex) -> Releasables.close(toClose);
nioChannel.getContext().sendMessage(new NioHttpResponse(sequence, resp), listener);
success = true;
} finally {
if (success == false) {
Releasables.close(toClose);
}
}
}
@Override
protected BytesStreamOutput newBytesOutput() {
return new ReleasableBytesStreamOutput(bigArrays);
}
private void setHeaderField(HttpResponse resp, String headerField, String value) {
setHeaderField(resp, headerField, value, true);
}
private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) {
if (override || !resp.headers().contains(headerField)) {
resp.headers().add(headerField, value);
}
}
private void addCookies(HttpResponse resp) {
if (resetCookies) {
String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
// Reset the cookies if necessary.
resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies));
}
}
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
// Create a new {@link HttpResponse} to transmit the response for the netty request.
private FullHttpResponse newResponse(ByteBuf buffer) {
final boolean http10 = isHttp10();
final boolean close = isCloseConnection();
// Build the response object.
final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize
final FullHttpResponse response;
if (http10) {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer);
if (!close) {
response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive");
}
} else {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer);
}
return response;
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0);
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) ||
(http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)));
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
getContext().sendMessage(response, ActionListener.toBiConsumer(listener));
}
}

View File

@ -68,7 +68,7 @@ public class NioHttpPipeliningHandler extends ChannelDuplexHandler {
List<Tuple<NioHttpResponse, NettyListener>> readyResponses = aggregator.write(response, listener);
success = true;
for (Tuple<NioHttpResponse, NettyListener> responseToWrite : readyResponses) {
ctx.write(responseToWrite.v1().getResponse(), responseToWrite.v2());
ctx.write(responseToWrite.v1(), responseToWrite.v2());
}
} catch (IllegalStateException e) {
ctx.channel().close();

View File

@ -19,13 +19,20 @@
package org.elasticsearch.http.nio;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.AbstractMap;
import java.util.Collection;
@ -35,25 +42,17 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class NioHttpRequest extends RestRequest {
public class NioHttpRequest implements HttpRequest {
private final FullHttpRequest request;
private final BytesReference content;
private final HttpHeadersMap headers;
private final int sequence;
NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) {
super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers()));
this.request = request;
if (request.content().isReadable()) {
this.content = ByteBufUtils.toBytesReference(request.content());
} else {
this.content = BytesArray.EMPTY;
}
}
NioHttpRequest(NamedXContentRegistry xContentRegistry, Map<String, String> params, String uri, FullHttpRequest request) {
super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers()));
NioHttpRequest(FullHttpRequest request, int sequence) {
this.request = request;
headers = new HttpHeadersMap(request.headers());
this.sequence = sequence;
if (request.content().isReadable()) {
this.content = ByteBufUtils.toBytesReference(request.content());
} else {
@ -62,38 +61,38 @@ public class NioHttpRequest extends RestRequest {
}
@Override
public Method method() {
public RestRequest.Method method() {
HttpMethod httpMethod = request.method();
if (httpMethod == HttpMethod.GET)
return Method.GET;
return RestRequest.Method.GET;
if (httpMethod == HttpMethod.POST)
return Method.POST;
return RestRequest.Method.POST;
if (httpMethod == HttpMethod.PUT)
return Method.PUT;
return RestRequest.Method.PUT;
if (httpMethod == HttpMethod.DELETE)
return Method.DELETE;
return RestRequest.Method.DELETE;
if (httpMethod == HttpMethod.HEAD) {
return Method.HEAD;
return RestRequest.Method.HEAD;
}
if (httpMethod == HttpMethod.OPTIONS) {
return Method.OPTIONS;
return RestRequest.Method.OPTIONS;
}
if (httpMethod == HttpMethod.PATCH) {
return Method.PATCH;
return RestRequest.Method.PATCH;
}
if (httpMethod == HttpMethod.TRACE) {
return Method.TRACE;
return RestRequest.Method.TRACE;
}
if (httpMethod == HttpMethod.CONNECT) {
return Method.CONNECT;
return RestRequest.Method.CONNECT;
}
throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
@ -104,20 +103,66 @@ public class NioHttpRequest extends RestRequest {
return request.uri();
}
@Override
public boolean hasContent() {
return content.length() > 0;
}
@Override
public BytesReference content() {
return content;
}
public FullHttpRequest getRequest() {
@Override
public final Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
String cookieString = request.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
return ServerCookieEncoder.STRICT.encode(cookies);
}
}
return Collections.emptyList();
}
@Override
public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpRequest.HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
}
}
@Override
public HttpRequest removeHeader(String header) {
HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove(header);
HttpHeaders trailingHeaders = new DefaultHttpHeaders();
trailingHeaders.add(request.trailingHeaders());
trailingHeaders.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(),
request.content(), headersWithoutContentTypeHeader, trailingHeaders);
return new NioHttpRequest(requestWithoutHeader, sequence);
}
@Override
public NioHttpResponse createResponse(RestStatus status, BytesReference content) {
return new NioHttpResponse(this, status, content);
}
public FullHttpRequest nettyRequest() {
return request;
}
int sequence() {
return sequence;
}
/**
* A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications
* and due to the underlying implementation, it performs case insensitive lookups of key to values.

View File

@ -19,19 +19,100 @@
package org.elasticsearch.http.nio;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedMessage;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestStatus;
public class NioHttpResponse extends HttpPipelinedMessage {
import java.util.Collections;
import java.util.EnumMap;
import java.util.Map;
private final FullHttpResponse response;
public class NioHttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage {
public NioHttpResponse(int sequence, FullHttpResponse response) {
super(sequence);
this.response = response;
private final int sequence;
private final NioHttpRequest request;
NioHttpResponse(NioHttpRequest request, RestStatus status, BytesReference content) {
super(request.nettyRequest().protocolVersion(), getStatus(status), ByteBufUtils.toByteBuf(content));
this.sequence = request.sequence();
this.request = request;
}
public FullHttpResponse getResponse() {
return response;
@Override
public void addHeader(String name, String value) {
headers().add(name, value);
}
@Override
public boolean containsHeader(String name) {
return headers().contains(name);
}
@Override
public int getSequence() {
return sequence;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
public NioHttpRequest getRequest() {
return request;
}
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
}
}

View File

@ -42,7 +42,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.AbstractHttpServerTransport;
import org.elasticsearch.http.BindHttpException;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpStats;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
@ -53,11 +52,11 @@ import org.elasticsearch.nio.EventHandler;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannel;
import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.rest.RestUtils;
import org.elasticsearch.threadpool.ThreadPool;
@ -104,12 +103,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
(s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2),
(s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope);
private final BigArrays bigArrays;
private final ThreadPool threadPool;
private final NamedXContentRegistry xContentRegistry;
private final HttpHandlingSettings httpHandlingSettings;
private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean reuseAddress;
@ -124,16 +117,12 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) {
super(settings, networkService, threadPool, dispatcher);
this.bigArrays = bigArrays;
this.threadPool = threadPool;
this.xContentRegistry = xContentRegistry;
super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings);
int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);;
this.corsConfig = buildCorsConfig(settings);
this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings);
@ -148,10 +137,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents);
}
BigArrays getBigArrays() {
return bigArrays;
}
public Logger getLogger() {
return logger;
}
@ -335,17 +320,17 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
socketChannels.add(socketChannel);
}
private class HttpChannelFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> {
private class HttpChannelFactory extends ChannelFactory<NioServerSocketChannel, NioHttpChannel> {
private HttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize));
}
@Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioSocketChannel nioChannel = new NioSocketChannel(channel);
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioHttpChannel nioChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this,
httpHandlingSettings, xContentRegistry, corsConfig, threadPool.getThreadContext());
handlingSettings, corsConfig);
Consumer<Exception> exceptionHandler = (e) -> exceptionCaught(nioChannel, e);
SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline,
InboundChannelBuffer.allocatingInstance());

View File

@ -22,6 +22,7 @@ package org.elasticsearch.http.nio.cors;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings;
import org.elasticsearch.http.nio.NioHttpResponse;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@ -76,6 +78,14 @@ public class NioCorsHandler extends ChannelDuplexHandler {
ctx.fireChannelRead(msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass();
NioHttpResponse response = (NioHttpResponse) msg;
setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config);
ctx.write(response, promise);
}
public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) {
if (!config.isCorsSupportEnabled()) {
return;

View File

@ -23,29 +23,31 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
@ -55,6 +57,9 @@ import java.nio.ByteBuffer;
import java.util.List;
import java.util.function.BiConsumer;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL;
@ -64,7 +69,12 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEAD
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@ -72,7 +82,7 @@ import static org.mockito.Mockito.verify;
public class HttpReadWriteHandlerTests extends ESTestCase {
private HttpReadWriteHandler handler;
private NioSocketChannel nioSocketChannel;
private NioHttpChannel nioHttpChannel;
private NioHttpServerTransport transport;
private final RequestEncoder requestEncoder = new RequestEncoder();
@ -96,15 +106,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings),
SETTING_PIPELINING_MAX_EVENTS.getDefault(settings),
SETTING_CORS_ENABLED.getDefault(settings));
ThreadContext threadContext = new ThreadContext(settings);
nioSocketChannel = mock(NioSocketChannel.class);
handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY,
NioCorsConfigBuilder.forAnyOrigin().build(), threadContext);
nioHttpChannel = mock(NioHttpChannel.class);
handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build());
}
public void testSuccessfulDecodeHttpRequest() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
ByteBuf buf = requestEncoder.encode(httpRequest);
int slicePoint = randomInt(buf.writerIndex() - 1);
@ -113,22 +121,21 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex());
handler.consumeReads(toChannelBuffer(slicedBuf));
verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class));
verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));
handler.consumeReads(toChannelBuffer(slicedBuf2));
ArgumentCaptor<RestRequest> requestCaptor = ArgumentCaptor.forClass(RestRequest.class);
verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class));
ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));
NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue();
FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest();
assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion());
assertEquals(httpRequest.method(), nettyHttpRequest.method());
HttpRequest nioHttpRequest = requestCaptor.getValue();
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
}
public void testDecodeHttpRequestError() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
ByteBuf buf = requestEncoder.encode(httpRequest);
buf.setByte(0, ' ');
@ -137,15 +144,15 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
handler.consumeReads(toChannelBuffer(buf));
ArgumentCaptor<Throwable> exceptionCaptor = ArgumentCaptor.forClass(Throwable.class);
verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture());
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());
assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
}
public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false);
io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false);
HttpUtil.setContentLength(httpRequest, 1025);
HttpUtil.setKeepAlive(httpRequest, false);
@ -153,60 +160,176 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
handler.consumeReads(toChannelBuffer(buf));
verify(transport, times(0)).dispatchBadRequest(any(), any(), any());
verify(transport, times(0)).dispatchRequest(any(), any());
verify(transport, times(0)).incomingRequestError(any(), any(), any());
verify(transport, times(0)).incomingRequest(any(), any());
List<FlushOperation> flushOperations = handler.pollFlushOperations();
assertFalse(flushOperations.isEmpty());
FlushOperation flushOperation = flushOperations.get(0);
HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());
flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been
// flushed
verify(nioSocketChannel).close();
verify(nioHttpChannel).close();
}
@SuppressWarnings("unchecked")
public void testEncodeHttpResponse() throws IOException {
prepareHandlerForResponse(handler);
FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
NioHttpResponse pipelinedResponse = new NioHttpResponse(0, fullHttpResponse);
DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0);
NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0");
SocketChannelContext context = mock(SocketChannelContext.class);
HttpWriteOperation writeOperation = new HttpWriteOperation(context, pipelinedResponse, mock(BiConsumer.class));
HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class));
List<FlushOperation> flushOperations = handler.writeToBytes(writeOperation);
HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite()));
FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite()));
assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
}
private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException {
HttpMethod method = HttpMethod.GET;
HttpVersion version = HttpVersion.HTTP_1_1;
public void testCorsEnabledWithoutAllowOrigins() throws IOException {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() throws IOException {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() throws IOException {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() throws IOException {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() throws IOException {
final String originValue = NioCorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig);
prepareHandlerForResponse(handler);
DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
NioHttpRequest nioHttpRequest = new NioHttpRequest(httpRequest, 0);
BytesArray content = new BytesArray("content");
HttpResponse response = nioHttpRequest.createResponse(RestStatus.OK, content);
response.addHeader("Content-Length", Integer.toString(content.length()));
SocketChannelContext context = mock(SocketChannelContext.class);
List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));
FlushOperation flushOperation = flushOperations.get(0);
return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
}
private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {
HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD;
HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1;
String uri = "http://localhost:9090/" + randomAlphaOfLength(8);
HttpRequest request = new DefaultFullHttpRequest(version, method, uri);
io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri);
ByteBuf buf = requestEncoder.encode(request);
handler.consumeReads(toChannelBuffer(buf));
ArgumentCaptor<RestRequest> requestCaptor = ArgumentCaptor.forClass(RestRequest.class);
verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class));
ArgumentCaptor<NioHttpRequest> requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class);
verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));
NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue();
FullHttpRequest requestParsed = nioHttpRequest.getRequest();
assertNotNull(requestParsed);
assertEquals(requestParsed.method(), method);
assertEquals(requestParsed.protocolVersion(), version);
assertEquals(requestParsed.uri(), uri);
return requestParsed;
NioHttpRequest nioHttpRequest = requestCaptor.getValue();
assertNotNull(nioHttpRequest);
assertEquals(method.name(), nioHttpRequest.method().name());
if (version == HttpVersion.HTTP_1_1) {
assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
} else {
assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion());
}
assertEquals(nioHttpRequest.uri(), uri);
return nioHttpRequest;
}
private InboundChannelBuffer toChannelBuffer(ByteBuf buf) {
@ -226,11 +349,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
return buffer;
}
private static final int MAX = 16 * 1024 * 1024;
private static class RequestEncoder {
private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder());
private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder(), new HttpObjectAggregator(MAX));
private ByteBuf encode(HttpRequest httpRequest) {
private ByteBuf encode(io.netty.handler.codec.http.HttpRequest httpRequest) {
requestEncoder.writeOutbound(httpRequest);
return requestEncoder.readOutbound();
}
@ -238,9 +363,9 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
private static class ResponseDecoder {
private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder());
private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(MAX));
private HttpResponse decode(ByteBuf response) {
private FullHttpResponse decode(ByteBuf response) {
responseDecoder.writeInbound(response);
return responseDecoder.readInbound();
}

View File

@ -1,349 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.nio;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.function.BiConsumer;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class NioHttpChannelTests extends ESTestCase {
private ThreadPool threadPool;
private MockBigArrays bigArrays;
private NioSocketChannel nioChannel;
private SocketChannelContext channelContext;
@Before
public void setup() throws Exception {
nioChannel = mock(NioSocketChannel.class);
channelContext = mock(SocketChannelContext.class);
when(nioChannel.getContext()).thenReturn(channelContext);
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(ByteBufUtils.toByteBuf(new TestResponse().content())));
}
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = NioCorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
public void testHeadersSet() {
Settings settings = Settings.builder().build();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote");
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
// send a response
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, corsConfig,
threadPool.getThreadContext());
TestResponse resp = new TestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(channelContext).sendMessage(responseCaptor.capture(), any());
Object nioResponse = responseCaptor.getValue();
HttpResponse response = ((NioHttpResponse) nioResponse).getResponse();
assertThat(response.headers().get("non-existent-header"), nullValue());
assertThat(response.headers().get(customHeader), equalTo(customHeaderValue));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length())));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType()));
}
@SuppressWarnings("unchecked")
public void testReleaseInListener() throws IOException {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final NioHttpRequest request = new NioHttpRequest(registry, httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings,
corsConfig, threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
Class<BiConsumer<Void, Exception>> listenerClass = (Class<BiConsumer<Void, Exception>>) (Class) BiConsumer.class;
ArgumentCaptor<BiConsumer<Void, Exception>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(channelContext).sendMessage(any(), listenerCaptor.capture());
BiConsumer<Void, Exception> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.accept(null, null);
} else {
listener.accept(null, new ClosedChannelException());
}
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
@SuppressWarnings("unchecked")
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
final FullHttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
}
} else {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/");
if (!close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
}
}
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings,
corsConfig, threadPool.getThreadContext());
final TestResponse resp = new TestResponse();
channel.sendResponse(resp);
Class<BiConsumer<Void, Exception>> listenerClass = (Class<BiConsumer<Void, Exception>>) (Class) BiConsumer.class;
ArgumentCaptor<BiConsumer<Void, Exception>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(channelContext).sendMessage(any(), listenerCaptor.capture());
BiConsumer<Void, Exception> listener = listenerCaptor.getValue();
listener.accept(null, null);
if (close) {
verify(nioChannel, times(1)).close();
} else {
verify(nioChannel, times(0)).close();
}
}
private FullHttpResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, httpHandlingSettings, corsConfig,
threadPool.getThreadContext());
channel.sendResponse(new TestResponse());
// get the response
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(channelContext, atLeastOnce()).sendMessage(responseCaptor.capture(), any());
return ((NioHttpResponse) responseCaptor.getValue()).getResponse();
}
private static class TestResponse extends RestResponse {
private final BytesReference reference;
TestResponse() {
reference = ByteBufUtils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8));
}
@Override
public String contentType() {
return "text";
}
@Override
public BytesReference content() {
return reference;
}
@Override
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -19,15 +19,12 @@
package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.junit.After;
@ -55,7 +55,6 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static org.hamcrest.core.Is.is;
@ -190,11 +189,11 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
ArrayList<ChannelPromise> promises = new ArrayList<>();
for (int i = 1; i < requests.size(); ++i) {
final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK);
ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise);
int sequence = requests.get(i).getSequence();
NioHttpResponse resp = new NioHttpResponse(sequence, httpResponse);
HttpPipelinedRequest<FullHttpRequest> pipelinedRequest = requests.get(i);
NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
NioHttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.writeAndFlush(resp, promise);
}
@ -231,10 +230,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
}
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<LastHttpContent>> {
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
@Override
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<LastHttpContent> pipelinedRequest) {
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> pipelinedRequest) {
LastHttpContent request = pipelinedRequest.getRequest();
final QueryStringDecoder decoder;
if (request instanceof FullHttpRequest) {
@ -244,9 +243,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
}
final String uri = decoder.path().replace("/", "");
final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8);
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content);
httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes());
final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8));
NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content);
httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length()));
final CountDownLatch waitingLatch = new CountDownLatch(1);
waitingRequests.put(uri, waitingLatch);
@ -258,7 +258,7 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
waitingLatch.await(1000, TimeUnit.SECONDS);
final ChannelPromise promise = ctx.newPromise();
eventLoopService.submit(() -> {
ctx.write(new NioHttpResponse(pipelinedRequest.getSequence(), httpResponse), promise);
ctx.write(httpResponse, promise);
finishingLatch.countDown();
});
} catch (InterruptedException e) {

View File

@ -280,40 +280,6 @@ public class NioHttpServerTransportTests extends ESTestCase {
assertThat(causeReference.get(), instanceOf(TooLongFrameException.class));
}
public void testDispatchDoesNotModifyThreadContext() throws InterruptedException {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (NioHttpServerTransport transport =
new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
transport.start();
transport.dispatchRequest(null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchBadRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
// public void testReadTimeout() throws Exception {
// final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
//

View File

@ -21,6 +21,7 @@ package org.elasticsearch.http;
import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntSet;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.network.NetworkService;
@ -29,7 +30,9 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.PortsRange;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.threadpool.ThreadPool;
@ -48,11 +51,14 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT;
public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport {
public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport {
public final HttpHandlingSettings handlingSettings;
protected final NetworkService networkService;
protected final BigArrays bigArrays;
protected final ThreadPool threadPool;
protected final Dispatcher dispatcher;
private final NamedXContentRegistry xContentRegistry;
protected final String[] bindHosts;
protected final String[] publishHosts;
@ -61,11 +67,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
protected volatile BoundTransportAddress boundAddress;
protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) {
protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) {
super(settings);
this.networkService = networkService;
this.bigArrays = bigArrays;
this.threadPool = threadPool;
this.xContentRegistry = xContentRegistry;
this.dispatcher = dispatcher;
this.handlingSettings = HttpHandlingSettings.fromSettings(settings);
// we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here
List<String> httpBindHost = SETTING_HTTP_BIND_HOST.get(settings);
@ -156,17 +166,94 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
return publishPort;
}
public void dispatchRequest(final RestRequest request, final RestChannel channel) {
/**
* This method handles an incoming http request.
*
* @param httpRequest that is incoming
* @param httpChannel that received the http request
*/
public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) {
handleIncomingRequest(httpRequest, httpChannel, null);
}
/**
* This method handles an incoming http request that has encountered an error.
*
* @param httpRequest that is incoming
* @param httpChannel that received the http request
* @param exception that was encountered
*/
public void incomingRequestError(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
handleIncomingRequest(httpRequest, httpChannel, exception);
}
// Visible for testing
void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
dispatcher.dispatchRequest(request, channel, threadContext);
if (badRequestCause != null) {
dispatcher.dispatchBadRequest(restRequest, channel, threadContext, badRequestCause);
} else {
dispatcher.dispatchRequest(restRequest, channel, threadContext);
}
}
}
public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) {
final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
dispatcher.dispatchBadRequest(request, channel, threadContext, cause);
private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
Exception badRequestCause = exception;
/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final RestRequest restRequest;
{
RestRequest innerRestRequest;
try {
innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
}
restRequest = innerRestRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these
* parameter values.
*/
final RestChannel channel;
{
RestChannel innerChannel;
ThreadContext threadContext = threadPool.getThreadContext();
try {
innerChannel = new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext);
} catch (final IllegalArgumentException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
innerChannel = new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext);
}
channel = innerChannel;
}
dispatchRequest(restRequest, channel, badRequestCause);
}
private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) {
HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type");
try {
return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel);
}
}
}

View File

@ -0,0 +1,172 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* The default rest channel for incoming requests. This class implements the basic logic for sending a rest
* response. It will set necessary headers nad ensure that bytes are released after the response is sent.
*/
public class DefaultRestChannel extends AbstractRestChannel implements RestChannel {
static final String CLOSE = "close";
static final String CONNECTION = "connection";
static final String KEEP_ALIVE = "keep-alive";
static final String CONTENT_TYPE = "content-type";
static final String CONTENT_LENGTH = "content-length";
static final String SET_COOKIE = "set-cookie";
static final String X_OPAQUE_ID = "X-Opaque-Id";
private final HttpRequest httpRequest;
private final BigArrays bigArrays;
private final HttpHandlingSettings settings;
private final ThreadContext threadContext;
private final HttpChannel httpChannel;
DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays,
HttpHandlingSettings settings, ThreadContext threadContext) {
super(request, settings.getDetailedErrorsEnabled());
this.httpChannel = httpChannel;
this.httpRequest = httpRequest;
this.bigArrays = bigArrays;
this.settings = settings;
this.threadContext = threadContext;
}
@Override
protected BytesStreamOutput newBytesOutput() {
return new ReleasableBytesStreamOutput(bigArrays);
}
@Override
public void sendResponse(RestResponse restResponse) {
HttpResponse httpResponse;
if (RestRequest.Method.HEAD == request.method()) {
httpResponse = httpRequest.createResponse(restResponse.status(), BytesArray.EMPTY);
} else {
httpResponse = httpRequest.createResponse(restResponse.status(), restResponse.content());
}
// TODO: Ideally we should move the setting of Cors headers into :server
// NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
String opaque = request.header(X_OPAQUE_ID);
if (opaque != null) {
setHeaderField(httpResponse, X_OPAQUE_ID, opaque);
}
// Add all custom headers
addCustomHeaders(httpResponse, restResponse.getHeaders());
addCustomHeaders(httpResponse, threadContext.getResponseHeaders());
ArrayList<Releasable> toClose = new ArrayList<>(3);
boolean success = false;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(httpResponse, CONTENT_TYPE, restResponse.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(httpResponse, CONTENT_LENGTH, String.valueOf(restResponse.content().length()), false);
addCookies(httpResponse);
BytesReference content = restResponse.content();
if (content instanceof Releasable) {
toClose.add((Releasable) content);
}
BytesStreamOutput bytesStreamOutput = bytesOutputOrNull();
if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) {
toClose.add((Releasable) bytesStreamOutput);
}
if (isCloseConnection()) {
toClose.add(httpChannel);
}
ActionListener<Void> listener = ActionListener.wrap(() -> Releasables.close(toClose));
httpChannel.sendResponse(httpResponse, listener);
success = true;
} finally {
if (success == false) {
Releasables.close(toClose);
}
}
}
private void setHeaderField(HttpResponse response, String headerField, String value) {
setHeaderField(response, headerField, value, true);
}
private void setHeaderField(HttpResponse response, String headerField, String value, boolean override) {
if (override || !response.containsHeader(headerField)) {
response.addHeader(headerField, value);
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
private void addCookies(HttpResponse response) {
if (settings.isResetCookies()) {
List<String> cookies = request.getHttpRequest().strictCookies();
if (cookies.isEmpty() == false) {
for (String cookie : cookies) {
response.addHeader(SET_COOKIE, cookie);
}
}
}
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION)));
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.lease.Releasable;
import java.net.InetSocketAddress;
public interface HttpChannel extends Releasable {
/**
* Sends a http response to the channel. The listener will be executed once the send process has been
* completed.
*
* @param response to send to channel
* @param listener to execute upon send completion
*/
void sendResponse(HttpResponse response, ActionListener<Void> listener);
/**
* Returns the local address for this channel.
*
* @return the local address of this channel.
*/
InetSocketAddress getLocalAddress();
/**
* Returns the remote address for this channel. Can be null if channel does not have a remote address.
*
* @return the remote address of this channel.
*/
InetSocketAddress getRemoteAddress();
/**
* Closes the channel. This might be an asynchronous process. There is no guarantee that the channel
* will be closed when this method returns.
*/
void close();
}

View File

@ -18,20 +18,17 @@
*/
package org.elasticsearch.http;
public class HttpPipelinedMessage implements Comparable<HttpPipelinedMessage> {
public interface HttpPipelinedMessage extends Comparable<HttpPipelinedMessage> {
private final int sequence;
public HttpPipelinedMessage(int sequence) {
this.sequence = sequence;
}
public int getSequence() {
return sequence;
}
/**
* Get the sequence number for this message.
*
* @return the sequence number
*/
int getSequence();
@Override
public int compareTo(HttpPipelinedMessage o) {
return Integer.compare(sequence, o.sequence);
default int compareTo(HttpPipelinedMessage o) {
return Integer.compare(getSequence(), o.getSequence());
}
}

View File

@ -18,15 +18,21 @@
*/
package org.elasticsearch.http;
public class HttpPipelinedRequest<R> extends HttpPipelinedMessage {
public class HttpPipelinedRequest<R> implements HttpPipelinedMessage {
private final R request;
private final int sequence;
HttpPipelinedRequest(int sequence, R request) {
super(sequence);
this.sequence = sequence;
this.request = request;
}
@Override
public int getSequence() {
return sequence;
}
public R getRequest() {
return request;
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.List;
import java.util.Map;
/**
* A basic http request abstraction. Http modules needs to implement this interface to integrate with the
* server package's rest handling.
*/
public interface HttpRequest {
enum HttpVersion {
HTTP_1_0,
HTTP_1_1
}
RestRequest.Method method();
/**
* The uri of the rest request, with the query string.
*/
String uri();
BytesReference content();
/**
* Get all of the headers and values associated with the headers. Modifications of this map are not supported.
*/
Map<String, List<String>> getHeaders();
List<String> strictCookies();
HttpVersion protocolVersion();
HttpRequest removeHeader(String header);
/**
* Create an http response from this request and the supplied status and content.
*/
HttpResponse createResponse(RestStatus status, BytesReference content);
}

View File

@ -0,0 +1,32 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
/**
* A basic http response abstraction. Http modules must implement this interface as the server package rest
* handling needs to set http headers for a response.
*/
public interface HttpResponse {
void addHeader(String name, String value);
boolean containsHeader(String name);
}

View File

@ -40,7 +40,7 @@ public abstract class AbstractRestChannel implements RestChannel {
private static final Predicate<String> EXCLUDE_FILTER = INCLUDE_FILTER.negate();
protected final RestRequest request;
protected final boolean detailedErrorsEnabled;
private final boolean detailedErrorsEnabled;
private final String format;
private final String filterPath;
private final boolean pretty;

View File

@ -272,8 +272,9 @@ public class RestController extends AbstractComponent implements HttpServerTrans
*/
private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) {
if (restRequest.getXContentType() == null) {
if (restHandler.supportsContentStream() && restRequest.header("Content-Type") != null) {
final String lowercaseMediaType = restRequest.header("Content-Type").toLowerCase(Locale.ROOT);
String contentTypeHeader = restRequest.header("Content-Type");
if (restHandler.supportsContentStream() && contentTypeHeader != null) {
final String lowercaseMediaType = contentTypeHeader.toLowerCase(Locale.ROOT);
// we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/
if (lowercaseMediaType.equals("application/x-ndjson")) {
restRequest.setXContentType(XContentType.JSON);

View File

@ -35,10 +35,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
@ -51,7 +52,7 @@ import java.util.stream.Collectors;
import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue;
import static org.elasticsearch.common.unit.TimeValue.parseTimeValue;
public abstract class RestRequest implements ToXContent.Params {
public class RestRequest implements ToXContent.Params {
// tchar pattern as defined by RFC7230 section 3.2.6
private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+");
@ -62,18 +63,47 @@ public abstract class RestRequest implements ToXContent.Params {
private final String rawPath;
private final Set<String> consumedParams = new HashSet<>();
private final SetOnce<XContentType> xContentType = new SetOnce<>();
private final HttpRequest httpRequest;
private final HttpChannel httpChannel;
protected RestRequest(NamedXContentRegistry xContentRegistry, Map<String, String> params, String path,
Map<String, List<String>> headers, HttpRequest httpRequest, HttpChannel httpChannel) {
final XContentType xContentType;
try {
xContentType = parseContentType(headers.get("Content-Type"));
} catch (final IllegalArgumentException e) {
throw new ContentTypeHeaderException(e);
}
if (xContentType != null) {
this.xContentType.set(xContentType);
}
this.xContentRegistry = xContentRegistry;
this.httpRequest = httpRequest;
this.httpChannel = httpChannel;
this.params = params;
this.rawPath = path;
this.headers = Collections.unmodifiableMap(headers);
}
protected RestRequest(RestRequest restRequest) {
this(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(),
restRequest.getHttpRequest(), restRequest.getHttpChannel());
}
/**
* Creates a new REST request.
* Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be
* decoded
*
* @param xContentRegistry the content registry
* @param uri the raw URI that will be parsed into the path and the parameters
* @param headers a map of the header; this map should implement a case-insensitive lookup
* @param httpRequest the http request
* @param httpChannel the http channel
* @throws BadParameterException if the parameters can not be decoded
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map<String, List<String>> headers) {
this(xContentRegistry, params(uri), path(uri), headers);
public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, HttpChannel httpChannel) {
Map<String, String> params = params(httpRequest.uri());
String path = path(httpRequest.uri());
return new RestRequest(xContentRegistry, params, path, httpRequest.getHeaders(), httpRequest, httpChannel);
}
private static Map<String, String> params(final String uri) {
@ -99,46 +129,34 @@ public abstract class RestRequest implements ToXContent.Params {
}
/**
* Creates a new REST request. In contrast to
* {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw
* a {@link BadParameterException}.
* Creates a new REST request. The path is not decoded so this constructor will not throw a
* {@link BadParameterException}.
*
* @param xContentRegistry the content registry
* @param params the request parameters
* @param path the raw path (which is not parsed)
* @param headers a map of the header; this map should implement a case-insensitive lookup
* @param httpRequest the http request
* @param httpChannel the http channel
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
public RestRequest(
final NamedXContentRegistry xContentRegistry,
final Map<String, String> params,
final String path,
final Map<String, List<String>> headers) {
final XContentType xContentType;
try {
xContentType = parseContentType(headers.get("Content-Type"));
} catch (final IllegalArgumentException e) {
throw new ContentTypeHeaderException(e);
}
if (xContentType != null) {
this.xContentType.set(xContentType);
}
this.xContentRegistry = xContentRegistry;
this.params = params;
this.rawPath = path;
this.headers = Collections.unmodifiableMap(headers);
public static RestRequest requestWithoutParameters(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest,
HttpChannel httpChannel) {
Map<String, String> params = Collections.emptyMap();
return new RestRequest(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel);
}
public enum Method {
GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT
}
public abstract Method method();
public Method method() {
return httpRequest.method();
}
/**
* The uri of the rest request, with the query string.
*/
public abstract String uri();
public String uri() {
return httpRequest.uri();
}
/**
* The non decoded, raw path provided.
@ -154,9 +172,13 @@ public abstract class RestRequest implements ToXContent.Params {
return RestUtils.decodeComponent(rawPath());
}
public abstract boolean hasContent();
public boolean hasContent() {
return content().length() > 0;
}
public abstract BytesReference content();
public BytesReference content() {
return httpRequest.content();
}
/**
* @return content of the request body or throw an exception if the body or content type is missing
@ -216,14 +238,12 @@ public abstract class RestRequest implements ToXContent.Params {
this.xContentType.set(xContentType);
}
@Nullable
public SocketAddress getRemoteAddress() {
return null;
public HttpChannel getHttpChannel() {
return httpChannel;
}
@Nullable
public SocketAddress getLocalAddress() {
return null;
public HttpRequest getHttpRequest() {
return httpRequest;
}
public final boolean hasParam(String key) {

View File

@ -20,10 +20,10 @@
package org.elasticsearch.rest;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -31,8 +31,7 @@ import java.util.Set;
public abstract class RestResponse {
protected Map<String, List<String>> customHeaders;
private Map<String, List<String>> customHeaders;
/**
* The response content type.
@ -81,10 +80,13 @@ public abstract class RestResponse {
}
/**
* Returns custom headers that have been added, or null if none have been set.
* Returns custom headers that have been added. This method should not be used to mutate headers.
*/
@Nullable
public Map<String, List<String>> getHeaders() {
return customHeaders;
if (customHeaders == null) {
return Collections.emptyMap();
} else {
return customHeaders;
}
}
}

View File

@ -19,13 +19,27 @@
package org.elasticsearch.http;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static java.net.InetAddress.getByName;
@ -36,6 +50,27 @@ import static org.hamcrest.Matchers.equalTo;
public class AbstractHttpServerTransportTests extends ESTestCase {
private NetworkService networkService;
private ThreadPool threadPool;
private MockBigArrays bigArrays;
@Before
public void setup() throws Exception {
networkService = new NetworkService(Collections.emptyList());
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
threadPool = null;
networkService = null;
bigArrays = null;
}
public void testHttpPublishPort() throws Exception {
int boundPort = randomIntBetween(9000, 9100);
int otherBoundPort = randomIntBetween(9200, 9300);
@ -71,6 +106,64 @@ public class AbstractHttpServerTransportTests extends ESTestCase {
}
}
public void testDispatchDoesNotModifyThreadContext() {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (AbstractHttpServerTransport transport =
new AbstractHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher) {
@Override
protected TransportAddress bindAddress(InetAddress hostAddress) {
return null;
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
@Override
protected void doClose() throws IOException {
}
@Override
public HttpStats stats() {
return null;
}
}) {
transport.dispatchRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchRequest(null, null, new Exception());
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
private TransportAddress address(String host, int port) throws UnknownHostException {
return new TransportAddress(getByName(host), port);
}

View File

@ -0,0 +1,444 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class DefaultRestChannelTests extends ESTestCase {
private ThreadPool threadPool;
private MockBigArrays bigArrays;
private HttpChannel httpChannel;
@Before
public void setup() {
httpChannel = mock(HttpChannel.class);
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final TestResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(new TestRestResponse().content()));
}
// TODO: Enable these Cors tests when the Cors logic lives in :server
// public void testCorsEnabledWithoutAllowOrigins() {
// // Set up a HTTP transport with only the CORS enabled setting
// Settings settings = Settings.builder()
// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
// }
//
// public void testCorsEnabledWithAllowOrigins() {
// final String originValue = "remote-host";
// // create a http transport with CORS enabled and allow origin configured
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testCorsAllowOriginWithSameHost() {
// String originValue = "remote-host";
// String host = "remote-host";
// // create a http transport with CORS enabled
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, host);
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = "http://" + originValue;
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue + ":5555";
// host = host + ":5555";
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue.replace("http", "https");
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testThatStringLiteralWorksOnMatch() {
// final String originValue = "remote-host";
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
// }
//
// public void testThatAnyOriginWorks() {
// final String originValue = NioCorsHandler.ANY_ORIGIN;
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
// }
public void testHeadersSet() {
Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
TestRestResponse resp = new TestRestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse httpResponse = responseCaptor.getValue();
Map<String, List<String>> headers = httpResponse.headers;
assertNull(headers.get("non-existent-header"));
assertEquals(customHeaderValue, headers.get(customHeader).get(0));
assertEquals("abc", headers.get(DefaultRestChannel.X_OPAQUE_ID).get(0));
assertEquals(Integer.toString(resp.content().length()), headers.get(DefaultRestChannel.CONTENT_LENGTH).get(0));
assertEquals(resp.contentType(), headers.get(DefaultRestChannel.CONTENT_TYPE).get(0));
}
public void testCookiesSet() {
Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse nioResponse = responseCaptor.getValue();
Map<String, List<String>> headers = nioResponse.headers;
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie"));
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2"));
}
@SuppressWarnings("unchecked")
public void testReleaseInListener() throws IOException {
final Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(httpChannel).sendResponse(any(), listenerCaptor.capture());
ActionListener<Void> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new ClosedChannelException());
}
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
@SuppressWarnings("unchecked")
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
final HttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
if (close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE));
}
} else {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
if (!close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE));
}
}
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(httpChannel).sendResponse(any(), listenerCaptor.capture());
ActionListener<Void> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new ClosedChannelException());
}
if (close) {
verify(httpChannel, times(1)).close();
} else {
verify(httpChannel, times(0)).close();
}
}
private TestResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private TestResponse executeRequest(final Settings settings, final String originValue, final String host) {
HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
// TODO: These exist for the Cors tests
// if (originValue != null) {
// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
// }
// httpRequest.headers().add(HttpHeaderNames.HOST, host);
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
// get the response
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any());
return responseCaptor.getValue();
}
private static class TestRequest implements HttpRequest {
private final HttpVersion version;
private final RestRequest.Method method;
private final String uri;
private HashMap<String, List<String>> headers = new HashMap<>();
private TestRequest(HttpVersion version, RestRequest.Method method, String uri) {
this.version = version;
this.method = method;
this.uri = uri;
}
@Override
public RestRequest.Method method() {
return method;
}
@Override
public String uri() {
return uri;
}
@Override
public BytesReference content() {
return BytesArray.EMPTY;
}
@Override
public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Arrays.asList("cookie", "cookie2");
}
@Override
public HttpVersion protocolVersion() {
return version;
}
@Override
public HttpRequest removeHeader(String header) {
throw new UnsupportedOperationException("Do not support removing header on test request.");
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return new TestResponse(status, content);
}
}
private static class TestResponse implements HttpResponse {
private final RestStatus status;
private final BytesReference content;
private final Map<String, List<String>> headers = new HashMap<>();
TestResponse(RestStatus status, BytesReference content) {
this.status = status;
this.content = content;
}
public String contentType() {
return "text";
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return status;
}
@Override
public void addHeader(String name, String value) {
if (headers.containsKey(name) == false) {
ArrayList<String> values = new ArrayList<>();
values.add(value);
headers.put(name, values);
} else {
headers.get(name).add(value);
}
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
}
private static class TestRestResponse extends RestResponse {
private final BytesReference content;
TestRestResponse() {
content = new BytesArray("content".getBytes(StandardCharsets.UTF_8));
}
public String contentType() {
return "text";
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -29,7 +29,6 @@ import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
@ -165,28 +164,7 @@ public class BytesRestResponseTests extends ESTestCase {
public void testResponseWhenPathContainsEncodingError() throws IOException {
final String path = "%a";
final RestRequest request =
new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) {
@Override
public Method method() {
return null;
}
@Override
public String uri() {
return null;
}
@Override
public boolean hasContent() {
return false;
}
@Override
public BytesReference content() {
return null;
}
};
final RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath(path).build();
final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath()));
final RestChannel channel = new DetailedExceptionRestChannel(request);
// if we try to decode the path, this will throw an IllegalArgumentException again

View File

@ -110,21 +110,21 @@ public class RestControllerTests extends ESTestCase {
RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
final RestController spyRestController = spy(restController);
when(spyRestController.getAllHandlers(fakeRequest))
.thenReturn(new Iterator<MethodHandlers>() {
@Override
public boolean hasNext() {
return false;
}
.thenReturn(new Iterator<MethodHandlers>() {
@Override
public boolean hasNext() {
return false;
}
@Override
public MethodHandlers next() {
return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> {
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
}, RestRequest.Method.GET);
}
});
@Override
public MethodHandlers next() {
return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> {
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
}, RestRequest.Method.GET);
}
});
AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
restController.dispatchRequest(fakeRequest, channel, threadContext);
// the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the
@ -136,7 +136,7 @@ public class RestControllerTests extends ESTestCase {
public void testCanTripCircuitBreaker() throws Exception {
RestController controller = new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService,
usageService);
usageService);
// trip circuit breaker by default
controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));
@ -209,7 +209,7 @@ public class RestControllerTests extends ESTestCase {
return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true);
};
final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null,
circuitBreakerService, usageService);
circuitBreakerService, usageService);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(), null, null, Optional.of(handler));
assertTrue(wrapperCalled.get());
@ -240,7 +240,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestAddsAndFreesBytesOnSuccess() {
int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON);
RestRequest request = testRestRequest("/", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -252,7 +252,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestAddsAndFreesBytesOnError() {
int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON);
RestRequest request = testRestRequest("/error", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -265,7 +265,7 @@ public class RestControllerTests extends ESTestCase {
int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength);
// we will produce an error in the rest handler and one more when sending the error response
TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON);
RestRequest request = testRestRequest("/error", content, XContentType.JSON);
ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -277,7 +277,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestLimitsBytes() {
int contentLength = BREAKER_LIMIT.bytesAsInt() + 1;
String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON);
RestRequest request = testRestRequest("/", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -288,11 +288,11 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequiresContentTypeForRequestsWithContent() {
String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
TestRestRequest request = new TestRestRequest("/", content, null);
RestRequest request = testRestRequest("/", content, null);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE);
restController = new RestController(
Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(),
Collections.emptySet(), null, null, circuitBreakerService, usageService);
Collections.emptySet(), null, null, circuitBreakerService, usageService);
restController.registerHandler(RestRequest.Method.GET, "/",
(r, c, client) -> c.sendResponse(
new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));
@ -412,8 +412,8 @@ public class RestControllerTests extends ESTestCase {
public void testNonStreamingXContentCausesErrorResponse() throws IOException {
FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY)
.withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()),
XContentType.YAML).withPath("/foo").build();
.withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()),
XContentType.YAML).withPath("/foo").build();
AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.NOT_ACCEPTABLE);
restController.registerHandler(RestRequest.Method.GET, "/foo", new RestHandler() {
@Override
@ -457,10 +457,10 @@ public class RestControllerTests extends ESTestCase {
final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build();
final AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST);
restController.dispatchBadRequest(
fakeRestRequest,
channel,
new ThreadContext(Settings.EMPTY),
randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request"));
fakeRestRequest,
channel,
new ThreadContext(Settings.EMPTY),
randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request"));
assertTrue(channel.getSendResponseCalled());
assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad request"));
}
@ -495,7 +495,7 @@ public class RestControllerTests extends ESTestCase {
@Override
public BoundTransportAddress boundAddress() {
TransportAddress transportAddress = buildNewFakeTransportAddress();
return new BoundTransportAddress(new TransportAddress[] {transportAddress} ,transportAddress);
return new BoundTransportAddress(new TransportAddress[]{transportAddress}, transportAddress);
}
@Override
@ -547,35 +547,11 @@ public class RestControllerTests extends ESTestCase {
}
}
private static final class TestRestRequest extends RestRequest {
private final BytesReference content;
private TestRestRequest(String path, String content, XContentType xContentType) {
super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, xContentType == null ?
Collections.emptyMap() : Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType())));
this.content = new BytesArray(content);
}
@Override
public Method method() {
return Method.GET;
}
@Override
public String uri() {
return null;
}
@Override
public boolean hasContent() {
return true;
}
@Override
public BytesReference content() {
return content;
}
private static RestRequest testRestRequest(String path, String content, XContentType xContentType) {
FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY);
builder.withPath(path);
builder.withContent(new BytesArray(content), xContentType);
return builder.build();
}
}

View File

@ -27,6 +27,7 @@ import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import java.io.IOException;
import java.util.ArrayList;
@ -44,66 +45,66 @@ import static org.hamcrest.Matchers.instanceOf;
public class RestRequestTests extends ESTestCase {
public void testContentParser() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentParser());
contentRestRequest("", emptyMap()).contentParser());
assertEquals("request body is required", e.getMessage());
e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", singletonMap("source", "{}")).contentParser());
contentRestRequest("", singletonMap("source", "{}")).contentParser());
assertEquals("request body is required", e.getMessage());
assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentParser().map());
assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentParser().map());
e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap(), emptyMap()).contentParser());
contentRestRequest("", emptyMap(), emptyMap()).contentParser());
assertEquals("request body is required", e.getMessage());
}
public void testApplyContentParser() throws IOException {
new ContentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called"));
new ContentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called"));
contentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called"));
contentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called"));
AtomicReference<Object> source = new AtomicReference<>();
new ContentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map()));
contentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map()));
assertEquals(emptyMap(), source.get());
}
public void testContentOrSourceParam() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentOrSourceParam());
contentRestRequest("", emptyMap()).contentOrSourceParam());
assertEquals("request body or source parameter is required", e.getMessage());
assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2());
assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2());
assertEquals(new BytesArray("stuff"),
new ContentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
contentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2());
assertEquals(new BytesArray("{\"foo\": \"stuff\"}"),
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder()
contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap())
.contentOrSourceParam().v2());
e = expectThrows(IllegalStateException.class, () ->
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder()
contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").immutableMap()).contentOrSourceParam());
assertEquals("source and source_content_type parameters are required", e.getMessage());
}
public void testHasContentOrSourceParam() throws IOException {
assertEquals(false, new ContentRestRequest("", emptyMap()).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("stuff", emptyMap()).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam());
assertEquals(false, contentRestRequest("", emptyMap()).hasContentOrSourceParam());
assertEquals(true, contentRestRequest("stuff", emptyMap()).hasContentOrSourceParam());
assertEquals(true, contentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam());
assertEquals(true, contentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam());
}
public void testContentOrSourceParamParser() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentOrSourceParamParser());
contentRestRequest("", emptyMap()).contentOrSourceParamParser());
assertEquals("request body or source parameter is required", e.getMessage());
assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map());
assertEquals(emptyMap(), new ContentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map());
assertEquals(emptyMap(), new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder()
assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map());
assertEquals(emptyMap(), contentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map());
assertEquals(emptyMap(), contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map());
}
public void testWithContentOrSourceParamParserOrNull() throws IOException {
new ContentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser));
new ContentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map()));
new ContentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser ->
contentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser));
contentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map()));
contentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser ->
assertEquals(emptyMap(), parser.map()));
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder().put("source_content_type", "application/json")
contentRestRequest("", MapBuilder.<String, String>newMapBuilder().put("source_content_type", "application/json")
.put("source", "{}").immutableMap())
.withContentOrSourceParamParserOrNull(parser ->
assertEquals(emptyMap(), parser.map()));
@ -113,18 +114,18 @@ public class RestRequestTests extends ESTestCase {
for (XContentType xContentType : XContentType.values()) {
Map<String, List<String>> map = new HashMap<>();
map.put("Content-Type", Collections.singletonList(xContentType.mediaType()));
ContentRestRequest restRequest = new ContentRestRequest("", Collections.emptyMap(), map);
RestRequest restRequest = contentRestRequest("", Collections.emptyMap(), map);
assertEquals(xContentType, restRequest.getXContentType());
map = new HashMap<>();
map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters()));
restRequest = new ContentRestRequest("", Collections.emptyMap(), map);
restRequest = contentRestRequest("", Collections.emptyMap(), map);
assertEquals(xContentType, restRequest.getXContentType());
}
}
public void testPlainTextSupport() {
ContentRestRequest restRequest = new ContentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(),
RestRequest restRequest = contentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(),
Collections.singletonMap("Content-Type",
Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8"))));
assertNull(restRequest.getXContentType());
@ -136,7 +137,7 @@ public class RestRequestTests extends ESTestCase {
RestRequest.ContentTypeHeaderException.class,
() -> {
final Map<String, List<String>> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type));
new ContentRestRequest("", Collections.emptyMap(), headers);
contentRestRequest("", Collections.emptyMap(), headers);
});
assertNotNull(e.getCause());
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
@ -144,7 +145,7 @@ public class RestRequestTests extends ESTestCase {
}
public void testNoContentTypeHeader() {
ContentRestRequest contentRestRequest = new ContentRestRequest("", Collections.emptyMap(), Collections.emptyMap());
RestRequest contentRestRequest = contentRestRequest("", Collections.emptyMap(), Collections.emptyMap());
assertNull(contentRestRequest.getXContentType());
}
@ -152,7 +153,7 @@ public class RestRequestTests extends ESTestCase {
List<String> headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10)));
final RestRequest.ContentTypeHeaderException e = expectThrows(
RestRequest.ContentTypeHeaderException.class,
() -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers)));
() -> contentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers)));
assertNotNull(e.getCause());
assertThat(e.getCause(), instanceOf((IllegalArgumentException.class)));
assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided"));
@ -160,52 +161,64 @@ public class RestRequestTests extends ESTestCase {
public void testRequiredContent() {
Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).requiredContent());
contentRestRequest("", emptyMap()).requiredContent());
assertEquals("request body is required", e.getMessage());
assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).requiredContent());
assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).requiredContent());
assertEquals(new BytesArray("stuff"),
new ContentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
contentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent());
e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder()
contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap())
.requiredContent());
assertEquals("request body is required", e.getMessage());
e = expectThrows(IllegalStateException.class, () ->
new ContentRestRequest("test", null, Collections.emptyMap()).requiredContent());
contentRestRequest("test", null, Collections.emptyMap()).requiredContent());
assertEquals("unknown content type", e.getMessage());
}
private static RestRequest contentRestRequest(String content, Map<String, String> params) {
Map<String, List<String>> headers = new HashMap<>();
headers.put("Content-Type", Collections.singletonList("application/json"));
return contentRestRequest(content, params, headers);
}
private static RestRequest contentRestRequest(String content, Map<String, String> params, Map<String, List<String>> headers) {
FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY);
builder.withHeaders(headers);
builder.withContent(new BytesArray(content), null);
builder.withParams(params);
return new ContentRestRequest(builder.build());
}
private static final class ContentRestRequest extends RestRequest {
private final BytesArray content;
ContentRestRequest(String content, Map<String, String> params) {
this(content, params, Collections.singletonMap("Content-Type", Collections.singletonList("application/json")));
}
private final RestRequest restRequest;
ContentRestRequest(String content, Map<String, String> params, Map<String, List<String>> headers) {
super(NamedXContentRegistry.EMPTY, params, "not used by this test", headers);
this.content = new BytesArray(content);
}
@Override
public boolean hasContent() {
return Strings.hasLength(content);
}
@Override
public BytesReference content() {
return content;
}
@Override
public String uri() {
throw new UnsupportedOperationException("Not used by this test");
private ContentRestRequest(RestRequest restRequest) {
super(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(),
restRequest.getHttpRequest(), restRequest.getHttpChannel());
this.restRequest = restRequest;
}
@Override
public Method method() {
throw new UnsupportedOperationException("Not used by this test");
return restRequest.method();
}
@Override
public String uri() {
return restRequest.uri();
}
@Override
public boolean hasContent() {
return Strings.hasLength(content());
}
@Override
public BytesReference content() {
return restRequest.content();
}
}
}

View File

@ -19,12 +19,18 @@
package org.elasticsearch.test.rest;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.net.SocketAddress;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -32,45 +38,115 @@ import java.util.Map;
public class FakeRestRequest extends RestRequest {
private final BytesReference content;
private final Method method;
private final SocketAddress remoteAddress;
public FakeRestRequest() {
this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null);
this(NamedXContentRegistry.EMPTY, new FakeHttpRequest(Method.GET, "", BytesArray.EMPTY, new HashMap<>()), new HashMap<>(),
new FakeHttpChannel(null));
}
private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map<String, List<String>> headers,
Map<String, String> params, BytesReference content, Method method, String path, SocketAddress remoteAddress) {
super(xContentRegistry, params, path, headers);
this.content = content;
this.method = method;
this.remoteAddress = remoteAddress;
}
@Override
public Method method() {
return method;
}
@Override
public String uri() {
return rawPath();
private FakeRestRequest(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, Map<String, String> params,
HttpChannel httpChannel) {
super(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel);
}
@Override
public boolean hasContent() {
return content != null;
return content() != null;
}
@Override
public BytesReference content() {
return content;
private static class FakeHttpRequest implements HttpRequest {
private final Method method;
private final String uri;
private final BytesReference content;
private final Map<String, List<String>> headers;
private FakeHttpRequest(Method method, String uri, BytesReference content, Map<String, List<String>> headers) {
this.method = method;
this.uri = uri;
this.content = content;
this.headers = headers;
}
@Override
public Method method() {
return method;
}
@Override
public String uri() {
return uri;
}
@Override
public BytesReference content() {
return content;
}
@Override
public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Collections.emptyList();
}
@Override
public HttpVersion protocolVersion() {
return HttpVersion.HTTP_1_1;
}
@Override
public HttpRequest removeHeader(String header) {
headers.remove(header);
return this;
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
Map<String, String> headers = new HashMap<>();
return new HttpResponse() {
@Override
public void addHeader(String name, String value) {
headers.put(name, value);
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
};
}
}
@Override
public SocketAddress getRemoteAddress() {
return remoteAddress;
private static class FakeHttpChannel implements HttpChannel {
private final InetSocketAddress remoteAddress;
private FakeHttpChannel(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return remoteAddress;
}
@Override
public void close() {
}
}
public static class Builder {
@ -86,7 +162,7 @@ public class FakeRestRequest extends RestRequest {
private Method method = Method.GET;
private SocketAddress address = null;
private InetSocketAddress address = null;
public Builder(NamedXContentRegistry xContentRegistry) {
this.xContentRegistry = xContentRegistry;
@ -120,15 +196,14 @@ public class FakeRestRequest extends RestRequest {
return this;
}
public Builder withRemoteAddress(SocketAddress address) {
public Builder withRemoteAddress(InetSocketAddress address) {
this.address = address;
return this;
}
public FakeRestRequest build() {
return new FakeRestRequest(xContentRegistry, headers, params, content, method, path, address);
FakeHttpRequest fakeHttpRequest = new FakeHttpRequest(method, path, content, headers);
return new FakeRestRequest(xContentRegistry, fakeHttpRequest, params, new FakeHttpChannel(address));
}
}
}

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.core.security.rest;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple;
@ -17,7 +16,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.rest.RestRequest;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.Map;
import java.util.Set;
@ -33,37 +31,15 @@ public interface RestRequestFilter {
default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException {
Set<String> fields = getFilteredFields();
if (restRequest.hasContent() && fields.isEmpty() == false) {
return new RestRequest(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders()) {
return new RestRequest(restRequest) {
private BytesReference filteredBytes = null;
@Override
public Method method() {
return restRequest.method();
}
@Override
public String uri() {
return restRequest.uri();
}
@Override
public boolean hasContent() {
return true;
}
@Nullable
@Override
public SocketAddress getRemoteAddress() {
return restRequest.getRemoteAddress();
}
@Nullable
@Override
public SocketAddress getLocalAddress() {
return restRequest.getLocalAddress();
}
@Override
public BytesReference content() {
if (filteredBytes == null) {

View File

@ -69,7 +69,6 @@ import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
@ -829,10 +828,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
msg.builder.field(Field.REQUEST_BODY, restRequestContent(request));
}
msg.builder.field(Field.ORIGIN_TYPE, "rest");
SocketAddress address = request.getRemoteAddress();
if (address instanceof InetSocketAddress) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress())
.getAddress()));
InetSocketAddress address = request.getHttpChannel().getRemoteAddress();
if (address != null) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress()));
} else {
msg.builder.field(Field.ORIGIN_ADDRESS, address);
}
@ -854,10 +852,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
msg.builder.field(Field.REQUEST_BODY, restRequestContent(request));
}
msg.builder.field(Field.ORIGIN_TYPE, "rest");
SocketAddress address = request.getRemoteAddress();
if (address instanceof InetSocketAddress) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress())
.getAddress()));
InetSocketAddress address = request.getHttpChannel().getRemoteAddress();
if (address != null) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress()));
} else {
msg.builder.field(Field.ORIGIN_ADDRESS, address);
}

View File

@ -38,7 +38,6 @@ import org.elasticsearch.xpack.security.transport.filter.SecurityIpFilterRule;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
@ -544,13 +543,8 @@ public class LoggingAuditTrail extends AbstractComponent implements AuditTrail,
}
private static String hostAttributes(RestRequest request) {
String formattedAddress;
final SocketAddress socketAddress = request.getRemoteAddress();
if (socketAddress instanceof InetSocketAddress) {
formattedAddress = NetworkAddress.format(((InetSocketAddress) socketAddress).getAddress());
} else {
formattedAddress = socketAddress.toString();
}
final InetSocketAddress socketAddress = request.getHttpChannel().getRemoteAddress();
String formattedAddress = NetworkAddress.format(socketAddress.getAddress());
return "origin_address=[" + formattedAddress + "]";
}

View File

@ -20,7 +20,7 @@ public class RemoteHostHeader {
* then be copied to the subsequent action requests.
*/
public static void process(RestRequest request, ThreadContext threadContext) {
threadContext.putTransient(KEY, request.getRemoteAddress());
threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress());
}
/**

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.security.rest;
import io.netty.channel.Channel;
import io.netty.handler.ssl.SslHandler;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
@ -13,7 +14,8 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.netty4.Netty4HttpRequest;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.netty4.Netty4HttpChannel;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
@ -50,10 +52,11 @@ public class SecurityRestFilter implements RestHandler {
if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) {
// CORS - allow for preflight unauthenticated OPTIONS request
if (extractClientCertificate) {
Netty4HttpRequest nettyHttpRequest = (Netty4HttpRequest) request;
SslHandler handler = nettyHttpRequest.getChannel().pipeline().get(SslHandler.class);
HttpChannel httpChannel = request.getHttpChannel();
Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel();
SslHandler handler = nettyChannel.pipeline().get(SslHandler.class);
assert handler != null;
ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel());
ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyChannel);
}
service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap(
authentication -> {

View File

@ -104,7 +104,7 @@ public class SecurityNetty4HttpServerTransport extends Netty4HttpServerTransport
private final class HttpSslChannelHandler extends HttpChannelHandler {
HttpSslChannelHandler() {
super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext());
super(SecurityNetty4HttpServerTransport.this, handlingSettings);
}
@Override

View File

@ -33,6 +33,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.plugins.MetaDataUpgrader;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestRequest;
@ -914,7 +915,9 @@ public class IndexAuditTrailTests extends SecurityIntegTestCase {
private RestRequest mockRestRequest() {
RestRequest request = mock(RestRequest.class);
when(request.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200));
HttpChannel httpChannel = mock(HttpChannel.class);
when(request.getHttpChannel()).thenReturn(httpChannel);
when(httpChannel.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200));
when(request.uri()).thenReturn("_uri");
return request;
}

View File

@ -88,6 +88,6 @@ public class RestRequestFilterTests extends ESTestCase {
new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON)
.withRemoteAddress(address).build();
RestRequest filtered = filter.getFilteredRequest(restRequest);
assertEquals(address, filtered.getRemoteAddress());
assertEquals(address, filtered.getHttpChannel().getRemoteAddress());
}
}

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
@ -67,6 +68,7 @@ public class SecurityRestFilterTests extends ESTestCase {
public void testProcess() throws Exception {
RestRequest request = mock(RestRequest.class);
when(request.getHttpChannel()).thenReturn(mock(HttpChannel.class));
Authentication authentication = mock(Authentication.class);
doAnswer((i) -> {
ActionListener callback =