Merge branch 'master' into index-lifecycle

This commit is contained in:
Tal Levy 2018-06-14 15:32:41 -07:00
commit 2af05e5480
77 changed files with 2471 additions and 2377 deletions

View File

@ -53,9 +53,23 @@ subprojects {
description = "Elasticsearch subproject ${project.path}" description = "Elasticsearch subproject ${project.path}"
} }
apply plugin: 'nebula.info-scm'
String licenseCommit
if (VersionProperties.elasticsearch.toString().endsWith('-SNAPSHOT')) {
licenseCommit = scminfo.change ?: "master" // leniency for non git builds
} else {
licenseCommit = "v${version}"
}
String elasticLicenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt"
subprojects { subprojects {
// Default to the apache license
project.ext.licenseName = 'The Apache Software License, Version 2.0' project.ext.licenseName = 'The Apache Software License, Version 2.0'
project.ext.licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt' project.ext.licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt'
// But stick the Elastic license url in project.ext so we can get it if we need to switch to it
project.ext.elasticLicenseUrl = elasticLicenseUrl
// we only use maven publish to add tasks for pom generation // we only use maven publish to add tasks for pom generation
plugins.withType(MavenPublishPlugin).whenPluginAdded { plugins.withType(MavenPublishPlugin).whenPluginAdded {
publishing { publishing {

View File

@ -228,6 +228,8 @@ subprojects {
check.dependsOn checkNotice check.dependsOn checkNotice
if (project.name == 'zip' || project.name == 'tar') { if (project.name == 'zip' || project.name == 'tar') {
project.ext.licenseName = 'Elastic License'
project.ext.licenseUrl = ext.elasticLicenseUrl
task checkMlCppNotice { task checkMlCppNotice {
dependsOn buildDist, checkExtraction dependsOn buildDist, checkExtraction
onlyIf toolExists onlyIf toolExists

View File

@ -19,252 +19,58 @@
package org.elasticsearch.http.netty4; package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse; import org.elasticsearch.action.ActionListener;
import io.netty.handler.codec.http.FullHttpRequest; import org.elasticsearch.http.HttpChannel;
import io.netty.handler.codec.http.FullHttpResponse; import org.elasticsearch.http.HttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.transport.netty4.Netty4Utils;
import java.util.Collections; import java.net.InetSocketAddress;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
final class Netty4HttpChannel extends AbstractRestChannel { public class Netty4HttpChannel implements HttpChannel {
private final Netty4HttpServerTransport transport;
private final Channel channel; private final Channel channel;
private final FullHttpRequest nettyRequest;
private final int sequence;
private final ThreadContext threadContext;
private final HttpHandlingSettings handlingSettings;
/** Netty4HttpChannel(Channel channel) {
* @param transport The corresponding <code>NettyHttpServerTransport</code> where this channel belongs to. this.channel = channel;
* @param request The request that is handled by this channel.
* @param sequence The pipelining sequence number for this request
* @param handlingSettings true if error messages should include stack traces.
* @param threadContext the thread context for the channel
*/
Netty4HttpChannel(Netty4HttpServerTransport transport, Netty4HttpRequest request, int sequence, HttpHandlingSettings handlingSettings,
ThreadContext threadContext) {
super(request, handlingSettings.getDetailedErrorsEnabled());
this.transport = transport;
this.channel = request.getChannel();
this.nettyRequest = request.request();
this.sequence = sequence;
this.threadContext = threadContext;
this.handlingSettings = handlingSettings;
} }
@Override @Override
protected BytesStreamOutput newBytesOutput() { public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
return new ReleasableBytesStreamOutput(transport.bigArrays); ChannelPromise writePromise = channel.newPromise();
writePromise.addListener(f -> {
if (f.isSuccess()) {
listener.onResponse(null);
} else {
final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause);
if (cause instanceof Error) {
listener.onFailure(new Exception(cause));
} else {
listener.onFailure((Exception) cause);
}
}
});
channel.writeAndFlush(response, writePromise);
} }
@Override @Override
public void sendResponse(RestResponse response) { public InetSocketAddress getLocalAddress() {
// if the response object was created upstream, then use it; return (InetSocketAddress) channel.localAddress();
// otherwise, create a new one
ByteBuf buffer = Netty4Utils.toByteBuf(response.content());
final FullHttpResponse resp;
if (HttpMethod.HEAD.equals(nettyRequest.method())) {
resp = newResponse(Unpooled.EMPTY_BUFFER);
} else {
resp = newResponse(buffer);
}
resp.setStatus(getStatus(response.status()));
Netty4CorsHandler.setCorsResponseHeaders(nettyRequest, resp, transport.getCorsConfig());
String opaque = nettyRequest.headers().get("X-Opaque-Id");
if (opaque != null) {
setHeaderField(resp, "X-Opaque-Id", opaque);
} }
// Add all custom headers @Override
addCustomHeaders(resp, response.getHeaders()); public InetSocketAddress getRemoteAddress() {
addCustomHeaders(resp, threadContext.getResponseHeaders()); return (InetSocketAddress) channel.remoteAddress();
BytesReference content = response.content();
boolean releaseContent = content instanceof Releasable;
boolean releaseBytesStreamOutput = bytesOutputOrNull() instanceof ReleasableBytesStreamOutput;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false);
addCookies(resp);
final ChannelPromise promise = channel.newPromise();
if (releaseContent) {
promise.addListener(f -> ((Releasable) content).close());
} }
if (releaseBytesStreamOutput) { @Override
promise.addListener(f -> bytesOutputOrNull().close()); public void close() {
channel.close();
} }
if (isCloseConnection()) { public Channel getNettyChannel() {
promise.addListener(ChannelFutureListener.CLOSE); return channel;
}
Netty4HttpResponse newResponse = new Netty4HttpResponse(sequence, resp);
channel.writeAndFlush(newResponse, promise);
releaseContent = false;
releaseBytesStreamOutput = false;
} finally {
if (releaseContent) {
((Releasable) content).close();
}
if (releaseBytesStreamOutput) {
bytesOutputOrNull().close();
}
}
}
private void setHeaderField(HttpResponse resp, String headerField, String value) {
setHeaderField(resp, headerField, value, true);
}
private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) {
if (override || !resp.headers().contains(headerField)) {
resp.headers().add(headerField, value);
}
}
private void addCookies(HttpResponse resp) {
if (handlingSettings.isResetCookies()) {
String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<io.netty.handler.codec.http.cookie.Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
// Reset the cookies if necessary.
resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies));
}
}
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0);
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) ||
(http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)));
}
// Create a new {@link HttpResponse} to transmit the response for the netty request.
private FullHttpResponse newResponse(ByteBuf buffer) {
final boolean http10 = isHttp10();
final boolean close = isCloseConnection();
// Build the response object.
final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize
final FullHttpResponse response;
if (http10) {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer);
if (!close) {
response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive");
}
} else {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer);
}
return response;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
} }
} }

View File

@ -66,7 +66,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
try { try {
List<Tuple<Netty4HttpResponse, ChannelPromise>> readyResponses = aggregator.write(response, promise); List<Tuple<Netty4HttpResponse, ChannelPromise>> readyResponses = aggregator.write(response, promise);
for (Tuple<Netty4HttpResponse, ChannelPromise> readyResponse : readyResponses) { for (Tuple<Netty4HttpResponse, ChannelPromise> readyResponse : readyResponses) {
ctx.write(readyResponse.v1().getResponse(), readyResponse.v2()); ctx.write(readyResponse.v1(), readyResponse.v2());
} }
success = true; success = true;
} catch (IllegalStateException e) { } catch (IllegalStateException e) {

View File

@ -19,17 +19,22 @@
package org.elasticsearch.http.netty4; package org.elasticsearch.http.netty4;
import io.netty.channel.Channel; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.transport.netty4.Netty4Utils;
import java.net.SocketAddress;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -38,25 +43,16 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class Netty4HttpRequest extends RestRequest { public class Netty4HttpRequest implements HttpRequest {
private final FullHttpRequest request; private final FullHttpRequest request;
private final Channel channel;
private final BytesReference content; private final BytesReference content;
private final HttpHeadersMap headers;
private final int sequence;
/** Netty4HttpRequest(FullHttpRequest request, int sequence) {
* Construct a new request.
*
* @param xContentRegistry the content registry
* @param request the underlying request
* @param channel the channel for the request
* @throws BadParameterException if the parameters can not be decoded
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) {
super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers()));
this.request = request; this.request = request;
this.channel = channel; headers = new HttpHeadersMap(request.headers());
this.sequence = sequence;
if (request.content().isReadable()) { if (request.content().isReadable()) {
this.content = Netty4Utils.toBytesReference(request.content()); this.content = Netty4Utils.toBytesReference(request.content());
} else { } else {
@ -64,71 +60,39 @@ public class Netty4HttpRequest extends RestRequest {
} }
} }
/**
* Construct a new request. In contrast to
* {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so
* this constructor will not throw a {@link BadParameterException}.
*
* @param xContentRegistry the content registry
* @param params the parameters for the request
* @param uri the path for the request
* @param request the underlying request
* @param channel the channel for the request
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/
Netty4HttpRequest(
final NamedXContentRegistry xContentRegistry,
final Map<String, String> params,
final String uri,
final FullHttpRequest request,
final Channel channel) {
super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers()));
this.request = request;
this.channel = channel;
if (request.content().isReadable()) {
this.content = Netty4Utils.toBytesReference(request.content());
} else {
this.content = BytesArray.EMPTY;
}
}
public FullHttpRequest request() {
return this.request;
}
@Override @Override
public Method method() { public RestRequest.Method method() {
HttpMethod httpMethod = request.method(); HttpMethod httpMethod = request.method();
if (httpMethod == HttpMethod.GET) if (httpMethod == HttpMethod.GET)
return Method.GET; return RestRequest.Method.GET;
if (httpMethod == HttpMethod.POST) if (httpMethod == HttpMethod.POST)
return Method.POST; return RestRequest.Method.POST;
if (httpMethod == HttpMethod.PUT) if (httpMethod == HttpMethod.PUT)
return Method.PUT; return RestRequest.Method.PUT;
if (httpMethod == HttpMethod.DELETE) if (httpMethod == HttpMethod.DELETE)
return Method.DELETE; return RestRequest.Method.DELETE;
if (httpMethod == HttpMethod.HEAD) { if (httpMethod == HttpMethod.HEAD) {
return Method.HEAD; return RestRequest.Method.HEAD;
} }
if (httpMethod == HttpMethod.OPTIONS) { if (httpMethod == HttpMethod.OPTIONS) {
return Method.OPTIONS; return RestRequest.Method.OPTIONS;
} }
if (httpMethod == HttpMethod.PATCH) { if (httpMethod == HttpMethod.PATCH) {
return Method.PATCH; return RestRequest.Method.PATCH;
} }
if (httpMethod == HttpMethod.TRACE) { if (httpMethod == HttpMethod.TRACE) {
return Method.TRACE; return RestRequest.Method.TRACE;
} }
if (httpMethod == HttpMethod.CONNECT) { if (httpMethod == HttpMethod.CONNECT) {
return Method.CONNECT; return RestRequest.Method.CONNECT;
} }
throw new IllegalArgumentException("Unexpected http method: " + httpMethod); throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
@ -139,40 +103,64 @@ public class Netty4HttpRequest extends RestRequest {
return request.uri(); return request.uri();
} }
@Override
public boolean hasContent() {
return content.length() > 0;
}
@Override @Override
public BytesReference content() { public BytesReference content() {
return content; return content;
} }
/**
* Returns the remote address where this rest request channel is "connected to". The
* returned {@link SocketAddress} is supposed to be down-cast into more
* concrete type such as {@link java.net.InetSocketAddress} to retrieve
* the detailed information.
*/
@Override @Override
public SocketAddress getRemoteAddress() { public final Map<String, List<String>> getHeaders() {
return channel.remoteAddress(); return headers;
} }
/**
* Returns the local address where this request channel is bound to. The returned
* {@link SocketAddress} is supposed to be down-cast into more concrete
* type such as {@link java.net.InetSocketAddress} to retrieve the detailed
* information.
*/
@Override @Override
public SocketAddress getLocalAddress() { public List<String> strictCookies() {
return channel.localAddress(); String cookieString = request.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
return ServerCookieEncoder.STRICT.encode(cookies);
}
}
return Collections.emptyList();
} }
public Channel getChannel() { @Override
return channel; public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpRequest.HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
}
}
@Override
public HttpRequest removeHeader(String header) {
HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove(header);
HttpHeaders trailingHeaders = new DefaultHttpHeaders();
trailingHeaders.add(request.trailingHeaders());
trailingHeaders.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(),
request.content(), headersWithoutContentTypeHeader, trailingHeaders);
return new Netty4HttpRequest(requestWithoutHeader, sequence);
}
@Override
public Netty4HttpResponse createResponse(RestStatus status, BytesReference content) {
return new Netty4HttpResponse(this, status, content);
}
public FullHttpRequest nettyRequest() {
return request;
}
int sequence() {
return sequence;
} }
/** /**

View File

@ -20,43 +20,31 @@
package org.elasticsearch.http.netty4; package org.elasticsearch.http.netty4;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.transport.netty4.Netty4Utils; import org.elasticsearch.transport.netty4.Netty4Utils;
import java.util.Collections;
@ChannelHandler.Sharable @ChannelHandler.Sharable
class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> { class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
private final Netty4HttpServerTransport serverTransport; private final Netty4HttpServerTransport serverTransport;
private final HttpHandlingSettings handlingSettings;
private final ThreadContext threadContext;
Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings, Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport) {
ThreadContext threadContext) {
this.serverTransport = serverTransport; this.serverTransport = serverTransport;
this.handlingSettings = handlingSettings;
this.threadContext = threadContext;
} }
@Override @Override
protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> msg) throws Exception { protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> msg) throws Exception {
final FullHttpRequest request = msg.getRequest(); Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get();
FullHttpRequest request = msg.getRequest();
try { try {
final FullHttpRequest copiedRequest =
final FullHttpRequest copy =
new DefaultFullHttpRequest( new DefaultFullHttpRequest(
request.protocolVersion(), request.protocolVersion(),
request.method(), request.method(),
@ -65,67 +53,18 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined
request.headers(), request.headers(),
request.trailingHeaders()); request.trailingHeaders());
Exception badRequestCause = null; Netty4HttpRequest httpRequest = new Netty4HttpRequest(copiedRequest, msg.getSequence());
/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final Netty4HttpRequest httpRequest;
{
Netty4HttpRequest innerHttpRequest;
try {
innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel());
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutParameters(copy, ctx.channel());
}
httpRequest = innerHttpRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these
* parameter values.
*/
final Netty4HttpChannel channel;
{
Netty4HttpChannel innerChannel;
try {
innerChannel =
new Netty4HttpChannel(serverTransport, httpRequest, msg.getSequence(), handlingSettings, threadContext);
} catch (final IllegalArgumentException e) {
if (badRequestCause == null) {
badRequestCause = e;
} else {
badRequestCause.addSuppressed(e);
}
final Netty4HttpRequest innerRequest =
new Netty4HttpRequest(
serverTransport.xContentRegistry,
Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters
copy.uri(),
copy,
ctx.channel());
innerChannel =
new Netty4HttpChannel(serverTransport, innerRequest, msg.getSequence(), handlingSettings, threadContext);
}
channel = innerChannel;
}
if (request.decoderResult().isFailure()) { if (request.decoderResult().isFailure()) {
serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); Throwable cause = request.decoderResult().cause();
} else if (badRequestCause != null) { if (cause instanceof Error) {
serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause); ExceptionsHelper.dieOnError(cause);
serverTransport.incomingRequestError(httpRequest, channel, new Exception(cause));
} else { } else {
serverTransport.dispatchRequest(httpRequest, channel); serverTransport.incomingRequestError(httpRequest, channel, (Exception) cause);
}
} else {
serverTransport.incomingRequest(httpRequest, channel);
} }
} finally { } finally {
// As we have copied the buffer, we can release the request // As we have copied the buffer, we can release the request
@ -133,32 +72,6 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler<HttpPipelined
} }
} }
private Netty4HttpRequest requestWithoutContentTypeHeader(
final FullHttpRequest request, final Channel channel, final Exception badRequestCause) {
final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove("Content-Type");
final FullHttpRequest requestWithoutContentTypeHeader =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
request.content(),
headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again
request.trailingHeaders()); // Content-Type can not be a trailing header
try {
return new Netty4HttpRequest(serverTransport.xContentRegistry, requestWithoutContentTypeHeader, channel);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return requestWithoutParameters(requestWithoutContentTypeHeader, channel);
}
}
private Netty4HttpRequest requestWithoutParameters(final FullHttpRequest request, final Channel channel) {
// remove all parameters as at least one is incorrectly encoded
return new Netty4HttpRequest(serverTransport.xContentRegistry, Collections.emptyMap(), request.uri(), request, channel);
}
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
Netty4Utils.maybeDie(cause); Netty4Utils.maybeDie(cause);

View File

@ -19,19 +19,103 @@
package org.elasticsearch.http.netty4; package org.elasticsearch.http.netty4;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedMessage; import org.elasticsearch.http.HttpPipelinedMessage;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;
public class Netty4HttpResponse extends HttpPipelinedMessage { import java.util.Collections;
import java.util.EnumMap;
import java.util.Map;
private final FullHttpResponse response; public class Netty4HttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage {
public Netty4HttpResponse(int sequence, FullHttpResponse response) { private final int sequence;
super(sequence); private final Netty4HttpRequest request;
this.response = response;
Netty4HttpResponse(Netty4HttpRequest request, RestStatus status, BytesReference content) {
super(request.nettyRequest().protocolVersion(), getStatus(status), Netty4Utils.toByteBuf(content));
this.sequence = request.sequence();
this.request = request;
} }
public FullHttpResponse getResponse() { @Override
return response; public void addHeader(String name, String value) {
headers().add(name, value);
} }
@Override
public boolean containsHeader(String name) {
return headers().contains(name);
} }
@Override
public int getSequence() {
return sequence;
}
public Netty4HttpRequest getRequest() {
return request;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
}
}

View File

@ -39,6 +39,7 @@ import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.handler.timeout.ReadTimeoutHandler; import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.AttributeKey;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier; import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
@ -53,9 +54,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.AbstractHttpServerTransport;
import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.BindHttpException;
import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpHandlingSettings;
@ -149,38 +148,29 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
public static final Setting<ByteSizeValue> SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE = public static final Setting<ByteSizeValue> SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE =
Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope);
protected final BigArrays bigArrays; private final ByteSizeValue maxInitialLineLength;
private final ByteSizeValue maxHeaderSize;
private final ByteSizeValue maxChunkSize;
protected final ByteSizeValue maxInitialLineLength; private final int workerCount;
protected final ByteSizeValue maxHeaderSize;
protected final ByteSizeValue maxChunkSize;
protected final int workerCount; private final int pipeliningMaxEvents;
protected final int pipeliningMaxEvents; private final boolean tcpNoDelay;
private final boolean tcpKeepAlive;
private final boolean reuseAddress;
/** private final ByteSizeValue tcpSendBufferSize;
* The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}. private final ByteSizeValue tcpReceiveBufferSize;
*/ private final RecvByteBufAllocator recvByteBufAllocator;
protected final NamedXContentRegistry xContentRegistry;
protected final boolean tcpNoDelay;
protected final boolean tcpKeepAlive;
protected final boolean reuseAddress;
protected final ByteSizeValue tcpSendBufferSize;
protected final ByteSizeValue tcpReceiveBufferSize;
protected final RecvByteBufAllocator recvByteBufAllocator;
private final int readTimeoutMillis; private final int readTimeoutMillis;
protected final int maxCompositeBufferComponents; private final int maxCompositeBufferComponents;
protected volatile ServerBootstrap serverBootstrap; protected volatile ServerBootstrap serverBootstrap;
protected final List<Channel> serverChannels = new ArrayList<>(); protected final List<Channel> serverChannels = new ArrayList<>();
protected final HttpHandlingSettings httpHandlingSettings;
// package private for testing // package private for testing
Netty4OpenChannelsHandler serverOpenChannels; Netty4OpenChannelsHandler serverOpenChannels;
@ -189,16 +179,13 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) {
super(settings, networkService, threadPool, dispatcher); super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings));
this.bigArrays = bigArrays;
this.xContentRegistry = xContentRegistry;
this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings);
this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings); this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings);
this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings); this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings);
@ -398,26 +385,27 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
} }
public ChannelHandler configureServerChannelHandler() { public ChannelHandler configureServerChannelHandler() {
return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext()); return new HttpChannelHandler(this, handlingSettings);
} }
static final AttributeKey<Netty4HttpChannel> HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel");
protected static class HttpChannelHandler extends ChannelInitializer<Channel> { protected static class HttpChannelHandler extends ChannelInitializer<Channel> {
private final Netty4HttpServerTransport transport; private final Netty4HttpServerTransport transport;
private final Netty4HttpRequestHandler requestHandler; private final Netty4HttpRequestHandler requestHandler;
private final HttpHandlingSettings handlingSettings; private final HttpHandlingSettings handlingSettings;
protected HttpChannelHandler( protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) {
final Netty4HttpServerTransport transport,
final HttpHandlingSettings handlingSettings,
final ThreadContext threadContext) {
this.transport = transport; this.transport = transport;
this.handlingSettings = handlingSettings; this.handlingSettings = handlingSettings;
this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext); this.requestHandler = new Netty4HttpRequestHandler(transport);
} }
@Override @Override
protected void initChannel(Channel ch) throws Exception { protected void initChannel(Channel ch) throws Exception {
Netty4HttpChannel nettyTcpChannel = new Netty4HttpChannel(ch);
ch.attr(HTTP_CHANNEL_KEY).set(nettyTcpChannel);
ch.pipeline().addLast("openChannels", transport.serverOpenChannels); ch.pipeline().addLast("openChannels", transport.serverOpenChannels);
ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS));
final HttpRequestDecoder decoder = new HttpRequestDecoder( final HttpRequestDecoder decoder = new HttpRequestDecoder(

View File

@ -22,6 +22,7 @@ package org.elasticsearch.http.netty4.cors;
import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.http.netty4.Netty4HttpResponse;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -76,6 +78,14 @@ public class Netty4CorsHandler extends ChannelDuplexHandler {
ctx.fireChannelRead(msg); ctx.fireChannelRead(msg);
} }
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass();
Netty4HttpResponse response = (Netty4HttpResponse) msg;
setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config);
ctx.write(response, promise);;
}
public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) { public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) {
if (!config.isCorsSupportEnabled()) { if (!config.isCorsSupportEnabled()) {
return; return;

View File

@ -333,10 +333,10 @@ public class Netty4Transport extends TcpTransport {
addClosedExceptionLogger(ch); addClosedExceptionLogger(ch);
NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name); NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name);
ch.attr(CHANNEL_KEY).set(nettyTcpChannel); ch.attr(CHANNEL_KEY).set(nettyTcpChannel);
serverAcceptedChannel(nettyTcpChannel);
ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("logging", new ESLoggingHandler());
ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder());
ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name)); ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name));
serverAcceptedChannel(nettyTcpChannel);
} }
@Override @Override

View File

@ -98,9 +98,12 @@ public class NettyTcpChannel implements TcpChannel {
} else { } else {
final Throwable cause = f.cause(); final Throwable cause = f.cause();
Netty4Utils.maybeDie(cause); Netty4Utils.maybeDie(cause);
assert cause instanceof Exception; if (cause instanceof Error) {
listener.onFailure(new Exception(cause));
} else {
listener.onFailure((Exception) cause); listener.onFailure((Exception) cause);
} }
}
}); });
channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise); channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise);

View File

@ -0,0 +1,148 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class Netty4CorsTests extends ESTestCase {
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = Netty4CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
EmbeddedChannel embeddedChannel = new EmbeddedChannel();
embeddedChannel.pipeline().addLast(new Netty4CorsHandler(Netty4HttpServerTransport.buildCorsConfig(settings)));
Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest, 0);
embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content")));
return embeddedChannel.readOutbound();
}
}

View File

@ -1,616 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelId;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelProgressivePromise;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasablePagedBytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ByteArray;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class Netty4HttpChannelTests extends ESTestCase {
private NetworkService networkService;
private ThreadPool threadPool;
private MockBigArrays bigArrays;
@Before
public void setup() throws Exception {
networkService = new NetworkService(Collections.emptyList());
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(Netty4Utils.toByteBuf(new TestResponse().content())));
}
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = Netty4CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
public void testHeadersSet() {
Settings settings = Settings.builder().build();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(),
new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote");
final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
// send a response
Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
TestResponse resp = new TestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
assertThat(writtenObjects.size(), is(1));
HttpResponse response = ((Netty4HttpResponse) writtenObjects.get(0)).getResponse();
assertThat(response.headers().get("non-existent-header"), nullValue());
assertThat(response.headers().get(customHeader), equalTo(customHeaderValue));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length())));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType()));
}
}
public void testReleaseOnSendToClosedChannel() {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) {
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final TestResponse response = new TestResponse(bigArrays);
assertThat(response.content(), instanceOf(Releasable.class));
embeddedChannel.close();
channel.sendResponse(response);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
}
public void testReleaseOnSendToChannelAfterException() throws IOException {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) {
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
}
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
}
} else {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/");
if (!close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
}
}
final EmbeddedChannel embeddedChannel = new EmbeddedChannel();
final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, embeddedChannel);
// send a response, the channel close status should match
assertTrue(embeddedChannel.isOpen());
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
final Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
final TestResponse resp = new TestResponse();
channel.sendResponse(resp);
assertThat(embeddedChannel.isOpen(), equalTo(!close));
}
}
private FullHttpResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
try (Netty4HttpServerTransport httpServerTransport =
new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(),
new NullDispatcher())) {
httpServerTransport.start();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel();
final Netty4HttpRequest request =
new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel);
HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings;
Netty4HttpChannel channel =
new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext());
channel.sendResponse(new TestResponse());
// get the response
List<Object> writtenObjects = writeCapturingChannel.getWrittenObjects();
assertThat(writtenObjects.size(), is(1));
return ((Netty4HttpResponse) writtenObjects.get(0)).getResponse();
}
}
private static class WriteCapturingChannel implements Channel {
private List<Object> writtenObjects = new ArrayList<>();
@Override
public ChannelId id() {
return null;
}
@Override
public EventLoop eventLoop() {
return null;
}
@Override
public Channel parent() {
return null;
}
@Override
public ChannelConfig config() {
return null;
}
@Override
public boolean isOpen() {
return false;
}
@Override
public boolean isRegistered() {
return false;
}
@Override
public boolean isActive() {
return false;
}
@Override
public ChannelMetadata metadata() {
return null;
}
@Override
public SocketAddress localAddress() {
return null;
}
@Override
public SocketAddress remoteAddress() {
return null;
}
@Override
public ChannelFuture closeFuture() {
return null;
}
@Override
public boolean isWritable() {
return false;
}
@Override
public long bytesBeforeUnwritable() {
return 0;
}
@Override
public long bytesBeforeWritable() {
return 0;
}
@Override
public Unsafe unsafe() {
return null;
}
@Override
public ChannelPipeline pipeline() {
return null;
}
@Override
public ByteBufAllocator alloc() {
return null;
}
@Override
public Channel read() {
return null;
}
@Override
public Channel flush() {
return null;
}
@Override
public ChannelFuture bind(SocketAddress localAddress) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) {
return null;
}
@Override
public ChannelFuture disconnect() {
return null;
}
@Override
public ChannelFuture close() {
return null;
}
@Override
public ChannelFuture deregister() {
return null;
}
@Override
public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture disconnect(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture close(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture deregister(ChannelPromise promise) {
return null;
}
@Override
public ChannelFuture write(Object msg) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture write(Object msg, ChannelPromise promise) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelFuture writeAndFlush(Object msg) {
writtenObjects.add(msg);
return null;
}
@Override
public ChannelPromise newPromise() {
return null;
}
@Override
public ChannelProgressivePromise newProgressivePromise() {
return null;
}
@Override
public ChannelFuture newSucceededFuture() {
return null;
}
@Override
public ChannelFuture newFailedFuture(Throwable cause) {
return null;
}
@Override
public ChannelPromise voidPromise() {
return null;
}
@Override
public <T> Attribute<T> attr(AttributeKey<T> key) {
return null;
}
@Override
public <T> boolean hasAttr(AttributeKey<T> key) {
return false;
}
@Override
public int compareTo(Channel o) {
return 0;
}
List<Object> getWrittenObjects() {
return writtenObjects;
}
}
private static class TestResponse extends RestResponse {
private final BytesReference reference;
TestResponse() {
reference = Netty4Utils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8));
}
TestResponse(final BigArrays bigArrays) {
final byte[] bytes;
try {
bytes = "content".getBytes("UTF-8");
} catch (final UnsupportedEncodingException e) {
throw new AssertionError(e);
}
final ByteArray bigArray = bigArrays.newByteArray(bytes.length);
bigArray.set(0, bytes, 0, bytes.length);
reference = new ReleasablePagedBytesReference(bigArrays, bigArray, bytes.length, Releasables.releaseOnce(bigArray));
}
@Override
public String contentType() {
return "text";
}
@Override
public BytesReference content() {
return reference;
}
@Override
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -19,15 +19,12 @@
package org.elasticsearch.http.netty4; package org.elasticsearch.http.netty4;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.QueryStringDecoder;
import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.After; import org.junit.After;
@ -55,7 +55,6 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static org.hamcrest.core.Is.is; import static org.hamcrest.core.Is.is;
@ -191,11 +190,11 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
ArrayList<ChannelPromise> promises = new ArrayList<>(); ArrayList<ChannelPromise> promises = new ArrayList<>();
for (int i = 1; i < requests.size(); ++i) { for (int i = 1; i < requests.size(); ++i) {
final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK);
ChannelPromise promise = embeddedChannel.newPromise(); ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise); promises.add(promise);
int sequence = requests.get(i).getSequence(); HttpPipelinedRequest<FullHttpRequest> pipelinedRequest = requests.get(i);
Netty4HttpResponse resp = new Netty4HttpResponse(sequence, httpResponse); Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
Netty4HttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.writeAndFlush(resp, promise); embeddedChannel.writeAndFlush(resp, promise);
} }
@ -233,10 +232,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
} }
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<LastHttpContent>> { private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
@Override @Override
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<LastHttpContent> pipelinedRequest) { protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> pipelinedRequest) {
LastHttpContent request = pipelinedRequest.getRequest(); LastHttpContent request = pipelinedRequest.getRequest();
final QueryStringDecoder decoder; final QueryStringDecoder decoder;
if (request instanceof FullHttpRequest) { if (request instanceof FullHttpRequest) {
@ -246,9 +245,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
} }
final String uri = decoder.path().replace("/", ""); final String uri = decoder.path().replace("/", "");
final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8));
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); Netty4HttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content);
httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length()));
final CountDownLatch waitingLatch = new CountDownLatch(1); final CountDownLatch waitingLatch = new CountDownLatch(1);
waitingRequests.put(uri, waitingLatch); waitingRequests.put(uri, waitingLatch);
@ -260,7 +260,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
waitingLatch.await(1000, TimeUnit.SECONDS); waitingLatch.await(1000, TimeUnit.SECONDS);
final ChannelPromise promise = ctx.newPromise(); final ChannelPromise promise = ctx.newPromise();
eventLoopService.submit(() -> { eventLoopService.submit(() -> {
ctx.write(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); ctx.write(httpResponse, promise);
finishingLatch.countDown(); finishingLatch.countDown();
}); });
} catch (InterruptedException e) { } catch (InterruptedException e) {

View File

@ -26,22 +26,20 @@ import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.bytes.BytesArray;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler; import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.http.NullDispatcher;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -120,7 +118,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
@Override @Override
public ChannelHandler configureServerChannelHandler() { public ChannelHandler configureServerChannelHandler() {
return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext()); return new CustomHttpChannelHandler(this, executorService);
} }
@Override @Override
@ -135,8 +133,8 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
private final ExecutorService executorService; private final ExecutorService executorService;
CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) { CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService) {
super(transport, transport.httpHandlingSettings, threadContext); super(transport, transport.handlingSettings);
this.executorService = executorService; this.executorService = executorService;
} }
@ -187,8 +185,9 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8);
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); Netty4HttpRequest httpRequest = new Netty4HttpRequest(fullHttpRequest, pipelinedRequest.getSequence());
httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); Netty4HttpResponse response = httpRequest.createResponse(RestStatus.OK, new BytesArray(uri.getBytes(StandardCharsets.UTF_8)));
response.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
final boolean slow = uri.matches("/slow/\\d+"); final boolean slow = uri.matches("/slow/\\d+");
if (slow) { if (slow) {
@ -202,7 +201,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase {
} }
final ChannelPromise promise = ctx.newPromise(); final ChannelPromise promise = ctx.newPromise();
ctx.writeAndFlush(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); ctx.writeAndFlush(response, promise);
} }
} }

View File

@ -291,40 +291,6 @@ public class Netty4HttpServerTransportTests extends ESTestCase {
assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); assertThat(causeReference.get(), instanceOf(TooLongFrameException.class));
} }
public void testDispatchDoesNotModifyThreadContext() throws InterruptedException {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (Netty4HttpServerTransport transport =
new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
transport.start();
transport.dispatchRequest(null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchBadRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
public void testReadTimeout() throws Exception { public void testReadTimeout() throws Exception {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {

View File

@ -23,54 +23,38 @@ import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentCompressor;
import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpResponseEncoder;
import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.ReadWriteHandler;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.nio.WriteOperation;
import org.elasticsearch.rest.RestRequest;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
public class HttpReadWriteHandler implements ReadWriteHandler { public class HttpReadWriteHandler implements ReadWriteHandler {
private final NettyAdaptor adaptor; private final NettyAdaptor adaptor;
private final NioSocketChannel nioChannel; private final NioHttpChannel nioHttpChannel;
private final NioHttpServerTransport transport; private final NioHttpServerTransport transport;
private final HttpHandlingSettings settings;
private final NamedXContentRegistry xContentRegistry;
private final NioCorsConfig corsConfig;
private final ThreadContext threadContext;
HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
NamedXContentRegistry xContentRegistry, NioCorsConfig corsConfig, ThreadContext threadContext) { NioCorsConfig corsConfig) {
this.nioChannel = nioChannel; this.nioHttpChannel = nioHttpChannel;
this.transport = transport; this.transport = transport;
this.settings = settings;
this.xContentRegistry = xContentRegistry;
this.corsConfig = corsConfig;
this.threadContext = threadContext;
List<ChannelHandler> handlers = new ArrayList<>(5); List<ChannelHandler> handlers = new ArrayList<>(5);
HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(), HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(),
@ -89,7 +73,7 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents())); handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents()));
adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0]));
adaptor.addCloseListener((v, e) -> nioChannel.close()); adaptor.addCloseListener((v, e) -> nioHttpChannel.close());
} }
@Override @Override
@ -150,95 +134,22 @@ public class HttpReadWriteHandler implements ReadWriteHandler {
request.headers(), request.headers(),
request.trailingHeaders()); request.trailingHeaders());
Exception badRequestCause = null; NioHttpRequest httpRequest = new NioHttpRequest(copiedRequest, pipelinedRequest.getSequence());
/*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final NioHttpRequest httpRequest;
{
NioHttpRequest innerHttpRequest;
try {
innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = e;
innerHttpRequest = requestWithoutParameters(copiedRequest);
}
httpRequest = innerHttpRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of
* these parameter values.
*/
final NioHttpChannel channel;
{
NioHttpChannel innerChannel;
int sequence = pipelinedRequest.getSequence();
BigArrays bigArrays = transport.getBigArrays();
try {
innerChannel = new NioHttpChannel(nioChannel, bigArrays, httpRequest, sequence, settings, corsConfig, threadContext);
} catch (final IllegalArgumentException e) {
if (badRequestCause == null) {
badRequestCause = e;
} else {
badRequestCause.addSuppressed(e);
}
final NioHttpRequest innerRequest =
new NioHttpRequest(
xContentRegistry,
Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters
copiedRequest.uri(),
copiedRequest);
innerChannel = new NioHttpChannel(nioChannel, bigArrays, innerRequest, sequence, settings, corsConfig, threadContext);
}
channel = innerChannel;
}
if (request.decoderResult().isFailure()) { if (request.decoderResult().isFailure()) {
transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); Throwable cause = request.decoderResult().cause();
} else if (badRequestCause != null) { if (cause instanceof Error) {
transport.dispatchBadRequest(httpRequest, channel, badRequestCause); ExceptionsHelper.dieOnError(cause);
transport.incomingRequestError(httpRequest, nioHttpChannel, new Exception(cause));
} else { } else {
transport.dispatchRequest(httpRequest, channel); transport.incomingRequestError(httpRequest, nioHttpChannel, (Exception) cause);
}
} else {
transport.incomingRequest(httpRequest, nioHttpChannel);
} }
} finally { } finally {
// As we have copied the buffer, we can release the request // As we have copied the buffer, we can release the request
request.release(); request.release();
} }
} }
private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) {
final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove("Content-Type");
final FullHttpRequest requestWithoutContentTypeHeader =
new DefaultFullHttpRequest(
request.protocolVersion(),
request.method(),
request.uri(),
request.content(),
headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again
request.trailingHeaders()); // Content-Type can not be a trailing header
try {
return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return requestWithoutParameters(requestWithoutContentTypeHeader);
}
}
private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) {
// remove all parameters as at least one is incorrectly encoded
return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request);
}
} }

View File

@ -19,244 +19,21 @@
package org.elasticsearch.http.nio; package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf; import org.elasticsearch.action.ActionListener;
import io.netty.buffer.Unpooled; import org.elasticsearch.http.HttpChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse; import org.elasticsearch.http.HttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import java.util.ArrayList; import java.io.IOException;
import java.util.Collections; import java.nio.channels.SocketChannel;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
public class NioHttpChannel extends AbstractRestChannel { public class NioHttpChannel extends NioSocketChannel implements HttpChannel {
private final BigArrays bigArrays; NioHttpChannel(SocketChannel socketChannel) throws IOException {
private final int sequence; super(socketChannel);
private final NioCorsConfig corsConfig;
private final ThreadContext threadContext;
private final FullHttpRequest nettyRequest;
private final NioSocketChannel nioChannel;
private final boolean resetCookies;
NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, int sequence,
HttpHandlingSettings settings, NioCorsConfig corsConfig, ThreadContext threadContext) {
super(request, settings.getDetailedErrorsEnabled());
this.nioChannel = nioChannel;
this.bigArrays = bigArrays;
this.sequence = sequence;
this.corsConfig = corsConfig;
this.threadContext = threadContext;
this.nettyRequest = request.getRequest();
this.resetCookies = settings.isResetCookies();
} }
@Override public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
public void sendResponse(RestResponse response) { getContext().sendMessage(response, ActionListener.toBiConsumer(listener));
// if the response object was created upstream, then use it;
// otherwise, create a new one
ByteBuf buffer = ByteBufUtils.toByteBuf(response.content());
final FullHttpResponse resp;
if (HttpMethod.HEAD.equals(nettyRequest.method())) {
resp = newResponse(Unpooled.EMPTY_BUFFER);
} else {
resp = newResponse(buffer);
}
resp.setStatus(getStatus(response.status()));
NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
String opaque = nettyRequest.headers().get("X-Opaque-Id");
if (opaque != null) {
setHeaderField(resp, "X-Opaque-Id", opaque);
}
// Add all custom headers
addCustomHeaders(resp, response.getHeaders());
addCustomHeaders(resp, threadContext.getResponseHeaders());
ArrayList<Releasable> toClose = new ArrayList<>(3);
boolean success = false;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false);
addCookies(resp);
BytesReference content = response.content();
if (content instanceof Releasable) {
toClose.add((Releasable) content);
}
BytesStreamOutput bytesStreamOutput = bytesOutputOrNull();
if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) {
toClose.add((Releasable) bytesStreamOutput);
}
if (isCloseConnection()) {
toClose.add(nioChannel::close);
}
BiConsumer<Void, Exception> listener = (aVoid, ex) -> Releasables.close(toClose);
nioChannel.getContext().sendMessage(new NioHttpResponse(sequence, resp), listener);
success = true;
} finally {
if (success == false) {
Releasables.close(toClose);
}
}
}
@Override
protected BytesStreamOutput newBytesOutput() {
return new ReleasableBytesStreamOutput(bigArrays);
}
private void setHeaderField(HttpResponse resp, String headerField, String value) {
setHeaderField(resp, headerField, value, true);
}
private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) {
if (override || !resp.headers().contains(headerField)) {
resp.headers().add(headerField, value);
}
}
private void addCookies(HttpResponse resp) {
if (resetCookies) {
String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
// Reset the cookies if necessary.
resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies));
}
}
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
// Create a new {@link HttpResponse} to transmit the response for the netty request.
private FullHttpResponse newResponse(ByteBuf buffer) {
final boolean http10 = isHttp10();
final boolean close = isCloseConnection();
// Build the response object.
final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize
final FullHttpResponse response;
if (http10) {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer);
if (!close) {
response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive");
}
} else {
response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer);
}
return response;
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0);
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) ||
(http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)));
}
private static Map<RestStatus, HttpResponseStatus> MAP;
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
} }
} }

View File

@ -68,7 +68,7 @@ public class NioHttpPipeliningHandler extends ChannelDuplexHandler {
List<Tuple<NioHttpResponse, NettyListener>> readyResponses = aggregator.write(response, listener); List<Tuple<NioHttpResponse, NettyListener>> readyResponses = aggregator.write(response, listener);
success = true; success = true;
for (Tuple<NioHttpResponse, NettyListener> responseToWrite : readyResponses) { for (Tuple<NioHttpResponse, NettyListener> responseToWrite : readyResponses) {
ctx.write(responseToWrite.v1().getResponse(), responseToWrite.v2()); ctx.write(responseToWrite.v1(), responseToWrite.v2());
} }
} catch (IllegalStateException e) { } catch (IllegalStateException e) {
ctx.channel().close(); ctx.channel().close();

View File

@ -19,13 +19,20 @@
package org.elasticsearch.http.nio; package org.elasticsearch.http.nio;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.cookie.Cookie;
import io.netty.handler.codec.http.cookie.ServerCookieDecoder;
import io.netty.handler.codec.http.cookie.ServerCookieEncoder;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.AbstractMap; import java.util.AbstractMap;
import java.util.Collection; import java.util.Collection;
@ -35,25 +42,17 @@ import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class NioHttpRequest extends RestRequest { public class NioHttpRequest implements HttpRequest {
private final FullHttpRequest request; private final FullHttpRequest request;
private final BytesReference content; private final BytesReference content;
private final HttpHeadersMap headers;
private final int sequence;
NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) { NioHttpRequest(FullHttpRequest request, int sequence) {
super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers()));
this.request = request;
if (request.content().isReadable()) {
this.content = ByteBufUtils.toBytesReference(request.content());
} else {
this.content = BytesArray.EMPTY;
}
}
NioHttpRequest(NamedXContentRegistry xContentRegistry, Map<String, String> params, String uri, FullHttpRequest request) {
super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers()));
this.request = request; this.request = request;
headers = new HttpHeadersMap(request.headers());
this.sequence = sequence;
if (request.content().isReadable()) { if (request.content().isReadable()) {
this.content = ByteBufUtils.toBytesReference(request.content()); this.content = ByteBufUtils.toBytesReference(request.content());
} else { } else {
@ -62,38 +61,38 @@ public class NioHttpRequest extends RestRequest {
} }
@Override @Override
public Method method() { public RestRequest.Method method() {
HttpMethod httpMethod = request.method(); HttpMethod httpMethod = request.method();
if (httpMethod == HttpMethod.GET) if (httpMethod == HttpMethod.GET)
return Method.GET; return RestRequest.Method.GET;
if (httpMethod == HttpMethod.POST) if (httpMethod == HttpMethod.POST)
return Method.POST; return RestRequest.Method.POST;
if (httpMethod == HttpMethod.PUT) if (httpMethod == HttpMethod.PUT)
return Method.PUT; return RestRequest.Method.PUT;
if (httpMethod == HttpMethod.DELETE) if (httpMethod == HttpMethod.DELETE)
return Method.DELETE; return RestRequest.Method.DELETE;
if (httpMethod == HttpMethod.HEAD) { if (httpMethod == HttpMethod.HEAD) {
return Method.HEAD; return RestRequest.Method.HEAD;
} }
if (httpMethod == HttpMethod.OPTIONS) { if (httpMethod == HttpMethod.OPTIONS) {
return Method.OPTIONS; return RestRequest.Method.OPTIONS;
} }
if (httpMethod == HttpMethod.PATCH) { if (httpMethod == HttpMethod.PATCH) {
return Method.PATCH; return RestRequest.Method.PATCH;
} }
if (httpMethod == HttpMethod.TRACE) { if (httpMethod == HttpMethod.TRACE) {
return Method.TRACE; return RestRequest.Method.TRACE;
} }
if (httpMethod == HttpMethod.CONNECT) { if (httpMethod == HttpMethod.CONNECT) {
return Method.CONNECT; return RestRequest.Method.CONNECT;
} }
throw new IllegalArgumentException("Unexpected http method: " + httpMethod); throw new IllegalArgumentException("Unexpected http method: " + httpMethod);
@ -104,20 +103,66 @@ public class NioHttpRequest extends RestRequest {
return request.uri(); return request.uri();
} }
@Override
public boolean hasContent() {
return content.length() > 0;
}
@Override @Override
public BytesReference content() { public BytesReference content() {
return content; return content;
} }
public FullHttpRequest getRequest() {
@Override
public final Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
String cookieString = request.headers().get(HttpHeaderNames.COOKIE);
if (cookieString != null) {
Set<Cookie> cookies = ServerCookieDecoder.STRICT.decode(cookieString);
if (!cookies.isEmpty()) {
return ServerCookieEncoder.STRICT.encode(cookies);
}
}
return Collections.emptyList();
}
@Override
public HttpVersion protocolVersion() {
if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) {
return HttpRequest.HttpVersion.HTTP_1_0;
} else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) {
return HttpRequest.HttpVersion.HTTP_1_1;
} else {
throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion());
}
}
@Override
public HttpRequest removeHeader(String header) {
HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders();
headersWithoutContentTypeHeader.add(request.headers());
headersWithoutContentTypeHeader.remove(header);
HttpHeaders trailingHeaders = new DefaultHttpHeaders();
trailingHeaders.add(request.trailingHeaders());
trailingHeaders.remove(header);
FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(),
request.content(), headersWithoutContentTypeHeader, trailingHeaders);
return new NioHttpRequest(requestWithoutHeader, sequence);
}
@Override
public NioHttpResponse createResponse(RestStatus status, BytesReference content) {
return new NioHttpResponse(this, status, content);
}
public FullHttpRequest nettyRequest() {
return request; return request;
} }
int sequence() {
return sequence;
}
/** /**
* A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications
* and due to the underlying implementation, it performs case insensitive lookups of key to values. * and due to the underlying implementation, it performs case insensitive lookups of key to values.

View File

@ -19,19 +19,100 @@
package org.elasticsearch.http.nio; package org.elasticsearch.http.nio;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedMessage; import org.elasticsearch.http.HttpPipelinedMessage;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestStatus;
public class NioHttpResponse extends HttpPipelinedMessage { import java.util.Collections;
import java.util.EnumMap;
import java.util.Map;
private final FullHttpResponse response; public class NioHttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage {
public NioHttpResponse(int sequence, FullHttpResponse response) { private final int sequence;
super(sequence); private final NioHttpRequest request;
this.response = response;
NioHttpResponse(NioHttpRequest request, RestStatus status, BytesReference content) {
super(request.nettyRequest().protocolVersion(), getStatus(status), ByteBufUtils.toByteBuf(content));
this.sequence = request.sequence();
this.request = request;
} }
public FullHttpResponse getResponse() { @Override
return response; public void addHeader(String name, String value) {
headers().add(name, value);
}
@Override
public boolean containsHeader(String name) {
return headers().contains(name);
}
@Override
public int getSequence() {
return sequence;
}
private static Map<RestStatus, HttpResponseStatus> MAP;
public NioHttpRequest getRequest() {
return request;
}
static {
EnumMap<RestStatus, HttpResponseStatus> map = new EnumMap<>(RestStatus.class);
map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE);
map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS);
map.put(RestStatus.OK, HttpResponseStatus.OK);
map.put(RestStatus.CREATED, HttpResponseStatus.CREATED);
map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED);
map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION);
map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT);
map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT);
map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT);
map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this??
map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES);
map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY);
map.put(RestStatus.FOUND, HttpResponseStatus.FOUND);
map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER);
map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED);
map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY);
map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT);
map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED);
map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED);
map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN);
map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND);
map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED);
map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE);
map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED);
map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT);
map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT);
map.put(RestStatus.GONE, HttpResponseStatus.GONE);
map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED);
map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED);
map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE);
map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG);
map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE);
map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE);
map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED);
map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST);
map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS);
map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR);
map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED);
map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY);
map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE);
map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT);
map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED);
MAP = Collections.unmodifiableMap(map);
}
private static HttpResponseStatus getStatus(RestStatus status) {
return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR);
} }
} }

View File

@ -42,7 +42,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.AbstractHttpServerTransport;
import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.BindHttpException;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.http.HttpStats; import org.elasticsearch.http.HttpStats;
import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsConfig;
@ -53,11 +52,11 @@ import org.elasticsearch.nio.EventHandler;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannel; import org.elasticsearch.nio.NioChannel;
import org.elasticsearch.nio.NioGroup; import org.elasticsearch.nio.NioGroup;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.ServerChannelContext;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.nio.NioSelector;
import org.elasticsearch.rest.RestUtils; import org.elasticsearch.rest.RestUtils;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -104,12 +103,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
(s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2),
(s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope); (s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope);
private final BigArrays bigArrays;
private final ThreadPool threadPool;
private final NamedXContentRegistry xContentRegistry;
private final HttpHandlingSettings httpHandlingSettings;
private final boolean tcpNoDelay; private final boolean tcpNoDelay;
private final boolean tcpKeepAlive; private final boolean tcpKeepAlive;
private final boolean reuseAddress; private final boolean reuseAddress;
@ -124,16 +117,12 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) { NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) {
super(settings, networkService, threadPool, dispatcher); super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher);
this.bigArrays = bigArrays;
this.threadPool = threadPool;
this.xContentRegistry = xContentRegistry;
ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings);
ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings);
ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings);
int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings);
this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);;
this.corsConfig = buildCorsConfig(settings); this.corsConfig = buildCorsConfig(settings);
this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings);
@ -148,10 +137,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents); maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents);
} }
BigArrays getBigArrays() {
return bigArrays;
}
public Logger getLogger() { public Logger getLogger() {
return logger; return logger;
} }
@ -335,17 +320,17 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
socketChannels.add(socketChannel); socketChannels.add(socketChannel);
} }
private class HttpChannelFactory extends ChannelFactory<NioServerSocketChannel, NioSocketChannel> { private class HttpChannelFactory extends ChannelFactory<NioServerSocketChannel, NioHttpChannel> {
private HttpChannelFactory() { private HttpChannelFactory() {
super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize));
} }
@Override @Override
public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException {
NioSocketChannel nioChannel = new NioSocketChannel(channel); NioHttpChannel nioChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this, HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this,
httpHandlingSettings, xContentRegistry, corsConfig, threadPool.getThreadContext()); handlingSettings, corsConfig);
Consumer<Exception> exceptionHandler = (e) -> exceptionCaught(nioChannel, e); Consumer<Exception> exceptionHandler = (e) -> exceptionCaught(nioChannel, e);
SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline, SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline,
InboundChannelBuffer.allocatingInstance()); InboundChannelBuffer.allocatingInstance());

View File

@ -22,6 +22,7 @@ package org.elasticsearch.http.nio.cors;
import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpHeaders;
@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.http.nio.NioHttpResponse;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -76,6 +78,14 @@ public class NioCorsHandler extends ChannelDuplexHandler {
ctx.fireChannelRead(msg); ctx.fireChannelRead(msg);
} }
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass();
NioHttpResponse response = (NioHttpResponse) msg;
setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config);
ctx.write(response, promise);
}
public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) { public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) {
if (!config.isCorsSupportEnabled()) { if (!config.isCorsSupportEnabled()) {
return; return;

View File

@ -23,29 +23,31 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder; import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion; import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder; import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.Before; import org.junit.Before;
import org.mockito.ArgumentCaptor; import org.mockito.ArgumentCaptor;
@ -55,6 +57,9 @@ import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL;
@ -64,7 +69,12 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEAD
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS; import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
@ -72,7 +82,7 @@ import static org.mockito.Mockito.verify;
public class HttpReadWriteHandlerTests extends ESTestCase { public class HttpReadWriteHandlerTests extends ESTestCase {
private HttpReadWriteHandler handler; private HttpReadWriteHandler handler;
private NioSocketChannel nioSocketChannel; private NioHttpChannel nioHttpChannel;
private NioHttpServerTransport transport; private NioHttpServerTransport transport;
private final RequestEncoder requestEncoder = new RequestEncoder(); private final RequestEncoder requestEncoder = new RequestEncoder();
@ -96,15 +106,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings), SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings),
SETTING_PIPELINING_MAX_EVENTS.getDefault(settings), SETTING_PIPELINING_MAX_EVENTS.getDefault(settings),
SETTING_CORS_ENABLED.getDefault(settings)); SETTING_CORS_ENABLED.getDefault(settings));
ThreadContext threadContext = new ThreadContext(settings); nioHttpChannel = mock(NioHttpChannel.class);
nioSocketChannel = mock(NioSocketChannel.class); handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build());
handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY,
NioCorsConfigBuilder.forAnyOrigin().build(), threadContext);
} }
public void testSuccessfulDecodeHttpRequest() throws IOException { public void testSuccessfulDecodeHttpRequest() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8); String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
ByteBuf buf = requestEncoder.encode(httpRequest); ByteBuf buf = requestEncoder.encode(httpRequest);
int slicePoint = randomInt(buf.writerIndex() - 1); int slicePoint = randomInt(buf.writerIndex() - 1);
@ -113,22 +121,21 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex());
handler.consumeReads(toChannelBuffer(slicedBuf)); handler.consumeReads(toChannelBuffer(slicedBuf));
verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class)); verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class));
handler.consumeReads(toChannelBuffer(slicedBuf2)); handler.consumeReads(toChannelBuffer(slicedBuf2));
ArgumentCaptor<RestRequest> requestCaptor = ArgumentCaptor.forClass(RestRequest.class); ArgumentCaptor<HttpRequest> requestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class));
NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); HttpRequest nioHttpRequest = requestCaptor.getValue();
FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest(); assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion()); assertEquals(RestRequest.Method.GET, nioHttpRequest.method());
assertEquals(httpRequest.method(), nettyHttpRequest.method());
} }
public void testDecodeHttpRequestError() throws IOException { public void testDecodeHttpRequestError() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8); String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri);
ByteBuf buf = requestEncoder.encode(httpRequest); ByteBuf buf = requestEncoder.encode(httpRequest);
buf.setByte(0, ' '); buf.setByte(0, ' ');
@ -137,15 +144,15 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
handler.consumeReads(toChannelBuffer(buf)); handler.consumeReads(toChannelBuffer(buf));
ArgumentCaptor<Throwable> exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture()); verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture());
assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException);
} }
public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException {
String uri = "localhost:9090/" + randomAlphaOfLength(8); String uri = "localhost:9090/" + randomAlphaOfLength(8);
HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false);
HttpUtil.setContentLength(httpRequest, 1025); HttpUtil.setContentLength(httpRequest, 1025);
HttpUtil.setKeepAlive(httpRequest, false); HttpUtil.setKeepAlive(httpRequest, false);
@ -153,60 +160,176 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
handler.consumeReads(toChannelBuffer(buf)); handler.consumeReads(toChannelBuffer(buf));
verify(transport, times(0)).dispatchBadRequest(any(), any(), any()); verify(transport, times(0)).incomingRequestError(any(), any(), any());
verify(transport, times(0)).dispatchRequest(any(), any()); verify(transport, times(0)).incomingRequest(any(), any());
List<FlushOperation> flushOperations = handler.pollFlushOperations(); List<FlushOperation> flushOperations = handler.pollFlushOperations();
assertFalse(flushOperations.isEmpty()); assertFalse(flushOperations.isEmpty());
FlushOperation flushOperation = flushOperations.get(0); FlushOperation flushOperation = flushOperations.get(0);
HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status());
flushOperation.getListener().accept(null, null); flushOperation.getListener().accept(null, null);
// Since we have keep-alive set to false, we should close the channel after the response has been // Since we have keep-alive set to false, we should close the channel after the response has been
// flushed // flushed
verify(nioSocketChannel).close(); verify(nioHttpChannel).close();
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testEncodeHttpResponse() throws IOException { public void testEncodeHttpResponse() throws IOException {
prepareHandlerForResponse(handler); prepareHandlerForResponse(handler);
FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
NioHttpResponse pipelinedResponse = new NioHttpResponse(0, fullHttpResponse); NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0);
NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0");
SocketChannelContext context = mock(SocketChannelContext.class); SocketChannelContext context = mock(SocketChannelContext.class);
HttpWriteOperation writeOperation = new HttpWriteOperation(context, pipelinedResponse, mock(BiConsumer.class)); HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class));
List<FlushOperation> flushOperations = handler.writeToBytes(writeOperation); List<FlushOperation> flushOperations = handler.writeToBytes(writeOperation);
HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite()));
assertEquals(HttpResponseStatus.OK, response.status()); assertEquals(HttpResponseStatus.OK, response.status());
assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion());
} }
private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException { public void testCorsEnabledWithoutAllowOrigins() throws IOException {
HttpMethod method = HttpMethod.GET; // Set up a HTTP transport with only the CORS enabled setting
HttpVersion version = HttpVersion.HTTP_1_1; Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() throws IOException {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() throws IOException {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeCorsRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() throws IOException {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() throws IOException {
final String originValue = NioCorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig);
prepareHandlerForResponse(handler);
DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
NioHttpRequest nioHttpRequest = new NioHttpRequest(httpRequest, 0);
BytesArray content = new BytesArray("content");
HttpResponse response = nioHttpRequest.createResponse(RestStatus.OK, content);
response.addHeader("Content-Length", Integer.toString(content.length()));
SocketChannelContext context = mock(SocketChannelContext.class);
List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));
FlushOperation flushOperation = flushOperations.get(0);
return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
}
private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {
HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD;
HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1;
String uri = "http://localhost:9090/" + randomAlphaOfLength(8); String uri = "http://localhost:9090/" + randomAlphaOfLength(8);
HttpRequest request = new DefaultFullHttpRequest(version, method, uri); io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri);
ByteBuf buf = requestEncoder.encode(request); ByteBuf buf = requestEncoder.encode(request);
handler.consumeReads(toChannelBuffer(buf)); handler.consumeReads(toChannelBuffer(buf));
ArgumentCaptor<RestRequest> requestCaptor = ArgumentCaptor.forClass(RestRequest.class); ArgumentCaptor<NioHttpRequest> requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class);
verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class));
NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); NioHttpRequest nioHttpRequest = requestCaptor.getValue();
FullHttpRequest requestParsed = nioHttpRequest.getRequest(); assertNotNull(nioHttpRequest);
assertNotNull(requestParsed); assertEquals(method.name(), nioHttpRequest.method().name());
assertEquals(requestParsed.method(), method); if (version == HttpVersion.HTTP_1_1) {
assertEquals(requestParsed.protocolVersion(), version); assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion());
assertEquals(requestParsed.uri(), uri); } else {
return requestParsed; assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion());
}
assertEquals(nioHttpRequest.uri(), uri);
return nioHttpRequest;
} }
private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { private InboundChannelBuffer toChannelBuffer(ByteBuf buf) {
@ -226,11 +349,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
return buffer; return buffer;
} }
private static final int MAX = 16 * 1024 * 1024;
private static class RequestEncoder { private static class RequestEncoder {
private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder()); private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder(), new HttpObjectAggregator(MAX));
private ByteBuf encode(HttpRequest httpRequest) { private ByteBuf encode(io.netty.handler.codec.http.HttpRequest httpRequest) {
requestEncoder.writeOutbound(httpRequest); requestEncoder.writeOutbound(httpRequest);
return requestEncoder.readOutbound(); return requestEncoder.readOutbound();
} }
@ -238,9 +363,9 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
private static class ResponseDecoder { private static class ResponseDecoder {
private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder()); private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(MAX));
private HttpResponse decode(ByteBuf response) { private FullHttpResponse decode(ByteBuf response) {
responseDecoder.writeInbound(response); responseDecoder.writeInbound(response);
return responseDecoder.readInbound(); return responseDecoder.readInbound();
} }

View File

@ -1,349 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.nio;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.nio.cors.NioCorsConfig;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.nio.NioSocketChannel;
import org.elasticsearch.nio.SocketChannelContext;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.function.BiConsumer;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class NioHttpChannelTests extends ESTestCase {
private ThreadPool threadPool;
private MockBigArrays bigArrays;
private NioSocketChannel nioChannel;
private SocketChannelContext channelContext;
@Before
public void setup() throws Exception {
nioChannel = mock(NioSocketChannel.class);
channelContext = mock(SocketChannelContext.class);
when(nioChannel.getContext()).thenReturn(channelContext);
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(ByteBufUtils.toByteBuf(new TestResponse().content())));
}
public void testCorsEnabledWithoutAllowOrigins() {
// Set up a HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create a http transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create a http transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = NioCorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
public void testHeadersSet() {
Settings settings = Settings.builder().build();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote");
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
// send a response
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, corsConfig,
threadPool.getThreadContext());
TestResponse resp = new TestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(channelContext).sendMessage(responseCaptor.capture(), any());
Object nioResponse = responseCaptor.getValue();
HttpResponse response = ((NioHttpResponse) nioResponse).getResponse();
assertThat(response.headers().get("non-existent-header"), nullValue());
assertThat(response.headers().get(customHeader), equalTo(customHeaderValue));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length())));
assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType()));
}
@SuppressWarnings("unchecked")
public void testReleaseInListener() throws IOException {
final Settings settings = Settings.builder().build();
final NamedXContentRegistry registry = xContentRegistry();
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
final NioHttpRequest request = new NioHttpRequest(registry, httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings,
corsConfig, threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
Class<BiConsumer<Void, Exception>> listenerClass = (Class<BiConsumer<Void, Exception>>) (Class) BiConsumer.class;
ArgumentCaptor<BiConsumer<Void, Exception>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(channelContext).sendMessage(any(), listenerCaptor.capture());
BiConsumer<Void, Exception> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.accept(null, null);
} else {
listener.accept(null, new ClosedChannelException());
}
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
@SuppressWarnings("unchecked")
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
final FullHttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
}
} else {
httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/");
if (!close) {
httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
}
}
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings,
corsConfig, threadPool.getThreadContext());
final TestResponse resp = new TestResponse();
channel.sendResponse(resp);
Class<BiConsumer<Void, Exception>> listenerClass = (Class<BiConsumer<Void, Exception>>) (Class) BiConsumer.class;
ArgumentCaptor<BiConsumer<Void, Exception>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(channelContext).sendMessage(any(), listenerCaptor.capture());
BiConsumer<Void, Exception> listener = listenerCaptor.getValue();
listener.accept(null, null);
if (close) {
verify(nioChannel, times(1)).close();
} else {
verify(nioChannel, times(0)).close();
}
}
private FullHttpResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest);
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings);
NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, httpHandlingSettings, corsConfig,
threadPool.getThreadContext());
channel.sendResponse(new TestResponse());
// get the response
ArgumentCaptor<Object> responseCaptor = ArgumentCaptor.forClass(Object.class);
verify(channelContext, atLeastOnce()).sendMessage(responseCaptor.capture(), any());
return ((NioHttpResponse) responseCaptor.getValue()).getResponse();
}
private static class TestResponse extends RestResponse {
private final BytesReference reference;
TestResponse() {
reference = ByteBufUtils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8));
}
@Override
public String contentType() {
return "text";
}
@Override
public BytesReference content() {
return reference;
}
@Override
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -19,15 +19,12 @@
package org.elasticsearch.http.nio; package org.elasticsearch.http.nio;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise; import io.netty.channel.ChannelPromise;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest; import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpMethod;
@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.LastHttpContent;
import io.netty.handler.codec.http.QueryStringDecoder; import io.netty.handler.codec.http.QueryStringDecoder;
import org.elasticsearch.common.Randomness; import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.junit.After; import org.junit.After;
@ -55,7 +55,6 @@ import java.util.stream.Collectors;
import java.util.stream.IntStream; import java.util.stream.IntStream;
import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH;
import static io.netty.handler.codec.http.HttpResponseStatus.OK;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static org.hamcrest.core.Is.is; import static org.hamcrest.core.Is.is;
@ -190,11 +189,11 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
ArrayList<ChannelPromise> promises = new ArrayList<>(); ArrayList<ChannelPromise> promises = new ArrayList<>();
for (int i = 1; i < requests.size(); ++i) { for (int i = 1; i < requests.size(); ++i) {
final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK);
ChannelPromise promise = embeddedChannel.newPromise(); ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise); promises.add(promise);
int sequence = requests.get(i).getSequence(); HttpPipelinedRequest<FullHttpRequest> pipelinedRequest = requests.get(i);
NioHttpResponse resp = new NioHttpResponse(sequence, httpResponse); NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
NioHttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.writeAndFlush(resp, promise); embeddedChannel.writeAndFlush(resp, promise);
} }
@ -231,10 +230,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
} }
private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<LastHttpContent>> { private class WorkEmulatorHandler extends SimpleChannelInboundHandler<HttpPipelinedRequest<FullHttpRequest>> {
@Override @Override
protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<LastHttpContent> pipelinedRequest) { protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest<FullHttpRequest> pipelinedRequest) {
LastHttpContent request = pipelinedRequest.getRequest(); LastHttpContent request = pipelinedRequest.getRequest();
final QueryStringDecoder decoder; final QueryStringDecoder decoder;
if (request instanceof FullHttpRequest) { if (request instanceof FullHttpRequest) {
@ -244,9 +243,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
} }
final String uri = decoder.path().replace("/", ""); final String uri = decoder.path().replace("/", "");
final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8));
final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence());
httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content);
httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length()));
final CountDownLatch waitingLatch = new CountDownLatch(1); final CountDownLatch waitingLatch = new CountDownLatch(1);
waitingRequests.put(uri, waitingLatch); waitingRequests.put(uri, waitingLatch);
@ -258,7 +258,7 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase {
waitingLatch.await(1000, TimeUnit.SECONDS); waitingLatch.await(1000, TimeUnit.SECONDS);
final ChannelPromise promise = ctx.newPromise(); final ChannelPromise promise = ctx.newPromise();
eventLoopService.submit(() -> { eventLoopService.submit(() -> {
ctx.write(new NioHttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); ctx.write(httpResponse, promise);
finishingLatch.countDown(); finishingLatch.countDown();
}); });
} catch (InterruptedException e) { } catch (InterruptedException e) {

View File

@ -280,40 +280,6 @@ public class NioHttpServerTransportTests extends ESTestCase {
assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); assertThat(causeReference.get(), instanceOf(TooLongFrameException.class));
} }
public void testDispatchDoesNotModifyThreadContext() throws InterruptedException {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (NioHttpServerTransport transport =
new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) {
transport.start();
transport.dispatchRequest(null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchBadRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
// public void testReadTimeout() throws Exception { // public void testReadTimeout() throws Exception {
// final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { // final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
// //

View File

@ -5,6 +5,20 @@ subprojects { Project subproj ->
subproj.tasks.withType(RestIntegTestTask) { subproj.tasks.withType(RestIntegTestTask) {
subproj.extensions.configure("${it.name}Cluster") { cluster -> subproj.extensions.configure("${it.name}Cluster") { cluster ->
cluster.distribution = System.getProperty('tests.distribution', 'oss-zip') cluster.distribution = System.getProperty('tests.distribution', 'oss-zip')
if (cluster.distribution == 'zip') {
/*
* Add Elastic's repositories so we can resolve older versions of the
* default distribution. Those aren't in maven central.
*/
repositories {
maven {
url "https://artifacts.elastic.co/maven"
}
maven {
url "https://snapshots.elastic.co/maven"
}
}
}
} }
} }
} }

View File

@ -50,7 +50,6 @@ import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -365,28 +364,11 @@ public class MasterService extends AbstractLifecycleComponent {
} }
public Discovery.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) { public Discovery.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) {
ArrayList<Discovery.AckListener> ackListeners = new ArrayList<>(); return new DelegatingAckListener(nonFailedTasks.stream()
.filter(task -> task.listener instanceof AckedClusterStateTaskListener)
//timeout straightaway, otherwise we could wait forever as the timeout thread has not started .map(task -> new AckCountDownListener((AckedClusterStateTaskListener) task.listener, newClusterState.version(),
nonFailedTasks.stream().filter(task -> task.listener instanceof AckedClusterStateTaskListener).forEach(task -> { newClusterState.nodes(), threadPool))
final AckedClusterStateTaskListener ackedListener = (AckedClusterStateTaskListener) task.listener; .collect(Collectors.toList()));
if (ackedListener.ackTimeout() == null || ackedListener.ackTimeout().millis() == 0) {
ackedListener.onAckTimeout();
} else {
try {
ackListeners.add(new AckCountDownListener(ackedListener, newClusterState.version(), newClusterState.nodes(),
threadPool));
} catch (EsRejectedExecutionException ex) {
if (logger.isDebugEnabled()) {
logger.debug("Couldn't schedule timeout thread - node might be shutting down", ex);
}
//timeout straightaway, otherwise we could wait forever as the timeout thread has not started
ackedListener.onAckTimeout();
}
}
});
return new DelegatingAckListener(ackListeners);
} }
public boolean clusterStateUnchanged() { public boolean clusterStateUnchanged() {
@ -549,6 +531,13 @@ public class MasterService extends AbstractLifecycleComponent {
this.listeners = listeners; this.listeners = listeners;
} }
@Override
public void onCommit(TimeValue commitTime) {
for (Discovery.AckListener listener : listeners) {
listener.onCommit(commitTime);
}
}
@Override @Override
public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { public void onNodeAck(DiscoveryNode node, @Nullable Exception e) {
for (Discovery.AckListener listener : listeners) { for (Discovery.AckListener listener : listeners) {
@ -564,14 +553,16 @@ public class MasterService extends AbstractLifecycleComponent {
private final AckedClusterStateTaskListener ackedTaskListener; private final AckedClusterStateTaskListener ackedTaskListener;
private final CountDown countDown; private final CountDown countDown;
private final DiscoveryNode masterNode; private final DiscoveryNode masterNode;
private final ThreadPool threadPool;
private final long clusterStateVersion; private final long clusterStateVersion;
private final Future<?> ackTimeoutCallback; private volatile Future<?> ackTimeoutCallback;
private Exception lastFailure; private Exception lastFailure;
AckCountDownListener(AckedClusterStateTaskListener ackedTaskListener, long clusterStateVersion, DiscoveryNodes nodes, AckCountDownListener(AckedClusterStateTaskListener ackedTaskListener, long clusterStateVersion, DiscoveryNodes nodes,
ThreadPool threadPool) { ThreadPool threadPool) {
this.ackedTaskListener = ackedTaskListener; this.ackedTaskListener = ackedTaskListener;
this.clusterStateVersion = clusterStateVersion; this.clusterStateVersion = clusterStateVersion;
this.threadPool = threadPool;
this.masterNode = nodes.getMasterNode(); this.masterNode = nodes.getMasterNode();
int countDown = 0; int countDown = 0;
for (DiscoveryNode node : nodes) { for (DiscoveryNode node : nodes) {
@ -581,8 +572,27 @@ public class MasterService extends AbstractLifecycleComponent {
} }
} }
logger.trace("expecting {} acknowledgements for cluster_state update (version: {})", countDown, clusterStateVersion); logger.trace("expecting {} acknowledgements for cluster_state update (version: {})", countDown, clusterStateVersion);
this.countDown = new CountDown(countDown); this.countDown = new CountDown(countDown + 1); // we also wait for onCommit to be called
this.ackTimeoutCallback = threadPool.schedule(ackedTaskListener.ackTimeout(), ThreadPool.Names.GENERIC, () -> onTimeout()); }
@Override
public void onCommit(TimeValue commitTime) {
TimeValue ackTimeout = ackedTaskListener.ackTimeout();
if (ackTimeout == null) {
ackTimeout = TimeValue.ZERO;
}
final TimeValue timeLeft = TimeValue.timeValueNanos(Math.max(0, ackTimeout.nanos() - commitTime.nanos()));
if (timeLeft.nanos() == 0L) {
onTimeout();
} else if (countDown.countDown()) {
finish();
} else {
this.ackTimeoutCallback = threadPool.schedule(timeLeft, ThreadPool.Names.GENERIC, this::onTimeout);
// re-check if onNodeAck has not completed while we were scheduling the timeout
if (countDown.isCountedDown()) {
FutureUtils.cancel(ackTimeoutCallback);
}
}
} }
@Override @Override
@ -599,11 +609,15 @@ public class MasterService extends AbstractLifecycleComponent {
} }
if (countDown.countDown()) { if (countDown.countDown()) {
finish();
}
}
private void finish() {
logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion);
FutureUtils.cancel(ackTimeoutCallback); FutureUtils.cancel(ackTimeoutCallback);
ackedTaskListener.onAllNodesAcked(lastFailure); ackedTaskListener.onAllNodesAcked(lastFailure);
} }
}
public void onTimeout() { public void onTimeout() {
if (countDown.fastForward()) { if (countDown.fastForward()) {

View File

@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.component.LifecycleComponent;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.TimeValue;
import java.io.IOException; import java.io.IOException;
@ -48,6 +49,19 @@ public interface Discovery extends LifecycleComponent {
void publish(ClusterChangedEvent clusterChangedEvent, AckListener ackListener); void publish(ClusterChangedEvent clusterChangedEvent, AckListener ackListener);
interface AckListener { interface AckListener {
/**
* Should be called when the discovery layer has committed the clusters state (i.e. even if this publication fails,
* it is guaranteed to appear in future publications).
* @param commitTime the time it took to commit the cluster state
*/
void onCommit(TimeValue commitTime);
/**
* Should be called whenever the discovery layer receives confirmation from a node that it has successfully applied
* the cluster state. In case of failures, an exception should be provided as parameter.
* @param node the node
* @param e the optional exception
*/
void onNodeAck(DiscoveryNode node, @Nullable Exception e); void onNodeAck(DiscoveryNode node, @Nullable Exception e);
} }

View File

@ -30,6 +30,7 @@ import org.elasticsearch.cluster.service.ClusterApplier.ClusterApplyListener;
import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.cluster.service.MasterService;
import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.discovery.DiscoveryStats; import org.elasticsearch.discovery.DiscoveryStats;
import org.elasticsearch.transport.TransportService; import org.elasticsearch.transport.TransportService;
@ -61,6 +62,7 @@ public class SingleNodeDiscovery extends AbstractLifecycleComponent implements D
public synchronized void publish(final ClusterChangedEvent event, public synchronized void publish(final ClusterChangedEvent event,
final AckListener ackListener) { final AckListener ackListener) {
clusterState = event.state(); clusterState = event.state();
ackListener.onCommit(TimeValue.ZERO);
CountDownLatch latch = new CountDownLatch(1); CountDownLatch latch = new CountDownLatch(1);
ClusterApplyListener listener = new ClusterApplyListener() { ClusterApplyListener listener = new ClusterApplyListener() {

View File

@ -158,7 +158,8 @@ public class PublishClusterStateAction extends AbstractComponent {
} }
try { try {
innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, sendFullVersion, serializedStates, serializedDiffs); innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, ackListener, sendFullVersion, serializedStates,
serializedDiffs);
} catch (Discovery.FailedToCommitClusterStateException t) { } catch (Discovery.FailedToCommitClusterStateException t) {
throw t; throw t;
} catch (Exception e) { } catch (Exception e) {
@ -173,8 +174,9 @@ public class PublishClusterStateAction extends AbstractComponent {
} }
private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final Set<DiscoveryNode> nodesToPublishTo, private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final Set<DiscoveryNode> nodesToPublishTo,
final SendingController sendingController, final boolean sendFullVersion, final SendingController sendingController, final Discovery.AckListener ackListener,
final Map<Version, BytesReference> serializedStates, final Map<Version, BytesReference> serializedDiffs) { final boolean sendFullVersion, final Map<Version, BytesReference> serializedStates,
final Map<Version, BytesReference> serializedDiffs) {
final ClusterState clusterState = clusterChangedEvent.state(); final ClusterState clusterState = clusterChangedEvent.state();
final ClusterState previousState = clusterChangedEvent.previousState(); final ClusterState previousState = clusterChangedEvent.previousState();
@ -195,8 +197,12 @@ public class PublishClusterStateAction extends AbstractComponent {
sendingController.waitForCommit(discoverySettings.getCommitTimeout()); sendingController.waitForCommit(discoverySettings.getCommitTimeout());
final long commitTime = System.nanoTime() - publishingStartInNanos;
ackListener.onCommit(TimeValue.timeValueNanos(commitTime));
try { try {
long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - (System.nanoTime() - publishingStartInNanos)); long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - commitTime);
final BlockingClusterStatePublishResponseHandler publishResponseHandler = sendingController.getPublishResponseHandler(); final BlockingClusterStatePublishResponseHandler publishResponseHandler = sendingController.getPublishResponseHandler();
sendingController.setPublishingTimedOut(!publishResponseHandler.awaitAllNodes(TimeValue.timeValueNanos(timeLeftInNanos))); sendingController.setPublishingTimedOut(!publishResponseHandler.awaitAllNodes(TimeValue.timeValueNanos(timeLeftInNanos)));
if (sendingController.getPublishingTimedOut()) { if (sendingController.getPublishingTimedOut()) {

View File

@ -29,6 +29,7 @@ import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.OutputStreamIndexOutput; import org.apache.lucene.store.OutputStreamIndexOutput;
import org.apache.lucene.store.SimpleFSDirectory; import org.apache.lucene.store.SimpleFSDirectory;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
@ -76,6 +77,7 @@ public abstract class MetaDataStateFormat<T> {
private final String prefix; private final String prefix;
private final Pattern stateFilePattern; private final Pattern stateFilePattern;
private static final Logger logger = Loggers.getLogger(MetaDataStateFormat.class);
/** /**
* Creates a new {@link MetaDataStateFormat} instance * Creates a new {@link MetaDataStateFormat} instance
@ -134,6 +136,7 @@ public abstract class MetaDataStateFormat<T> {
IOUtils.fsync(tmpStatePath, false); // fsync the state file IOUtils.fsync(tmpStatePath, false); // fsync the state file
Files.move(tmpStatePath, finalStatePath, StandardCopyOption.ATOMIC_MOVE); Files.move(tmpStatePath, finalStatePath, StandardCopyOption.ATOMIC_MOVE);
IOUtils.fsync(stateLocation, true); IOUtils.fsync(stateLocation, true);
logger.trace("written state to {}", finalStatePath);
for (int i = 1; i < locations.length; i++) { for (int i = 1; i < locations.length; i++) {
stateLocation = locations[i].resolve(STATE_DIR_NAME); stateLocation = locations[i].resolve(STATE_DIR_NAME);
Files.createDirectories(stateLocation); Files.createDirectories(stateLocation);
@ -145,12 +148,15 @@ public abstract class MetaDataStateFormat<T> {
// we are on the same FileSystem / Partition here we can do an atomic move // we are on the same FileSystem / Partition here we can do an atomic move
Files.move(tmpPath, finalPath, StandardCopyOption.ATOMIC_MOVE); Files.move(tmpPath, finalPath, StandardCopyOption.ATOMIC_MOVE);
IOUtils.fsync(stateLocation, true); IOUtils.fsync(stateLocation, true);
logger.trace("copied state to {}", finalPath);
} finally { } finally {
Files.deleteIfExists(tmpPath); Files.deleteIfExists(tmpPath);
logger.trace("cleaned up {}", tmpPath);
} }
} }
} finally { } finally {
Files.deleteIfExists(tmpStatePath); Files.deleteIfExists(tmpStatePath);
logger.trace("cleaned up {}", tmpStatePath);
} }
cleanupOldFiles(prefix, fileName, locations); cleanupOldFiles(prefix, fileName, locations);
} }
@ -211,20 +217,19 @@ public abstract class MetaDataStateFormat<T> {
} }
private void cleanupOldFiles(final String prefix, final String currentStateFile, Path[] locations) throws IOException { private void cleanupOldFiles(final String prefix, final String currentStateFile, Path[] locations) throws IOException {
final DirectoryStream.Filter<Path> filter = new DirectoryStream.Filter<Path>() { final DirectoryStream.Filter<Path> filter = entry -> {
@Override
public boolean accept(Path entry) throws IOException {
final String entryFileName = entry.getFileName().toString(); final String entryFileName = entry.getFileName().toString();
return Files.isRegularFile(entry) return Files.isRegularFile(entry)
&& entryFileName.startsWith(prefix) // only state files && entryFileName.startsWith(prefix) // only state files
&& currentStateFile.equals(entryFileName) == false; // keep the current state file around && currentStateFile.equals(entryFileName) == false; // keep the current state file around
}
}; };
// now clean up the old files // now clean up the old files
for (Path dataLocation : locations) { for (Path dataLocation : locations) {
logger.trace("cleanupOldFiles: cleaning up {}", dataLocation);
try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataLocation.resolve(STATE_DIR_NAME), filter)) { try (DirectoryStream<Path> stream = Files.newDirectoryStream(dataLocation.resolve(STATE_DIR_NAME), filter)) {
for (Path stateFile : stream) { for (Path stateFile : stream) {
Files.deleteIfExists(stateFile); Files.deleteIfExists(stateFile);
logger.trace("cleanupOldFiles: cleaned up {}", stateFile);
} }
} }
} }

View File

@ -123,6 +123,7 @@ public class MetaStateService extends AbstractComponent {
try { try {
IndexMetaData.FORMAT.write(indexMetaData, IndexMetaData.FORMAT.write(indexMetaData,
nodeEnv.indexPaths(indexMetaData.getIndex())); nodeEnv.indexPaths(indexMetaData.getIndex()));
logger.trace("[{}] state written", index);
} catch (Exception ex) { } catch (Exception ex) {
logger.warn(() -> new ParameterizedMessage("[{}]: failed to write index state", index), ex); logger.warn(() -> new ParameterizedMessage("[{}]: failed to write index state", index), ex);
throw new IOException("failed to write state for [" + index + "]", ex); throw new IOException("failed to write state for [" + index + "]", ex);
@ -136,6 +137,7 @@ public class MetaStateService extends AbstractComponent {
logger.trace("[_global] writing state, reason [{}]", reason); logger.trace("[_global] writing state, reason [{}]", reason);
try { try {
MetaData.FORMAT.write(metaData, nodeEnv.nodeDataPaths()); MetaData.FORMAT.write(metaData, nodeEnv.nodeDataPaths());
logger.trace("[_global] state written");
} catch (Exception ex) { } catch (Exception ex) {
logger.warn("[_global]: failed to write global state", ex); logger.warn("[_global]: failed to write global state", ex);
throw new IOException("failed to write global state", ex); throw new IOException("failed to write global state", ex);

View File

@ -21,6 +21,7 @@ package org.elasticsearch.http;
import com.carrotsearch.hppc.IntHashSet; import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntSet; import com.carrotsearch.hppc.IntSet;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkService;
@ -29,7 +30,9 @@ import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.PortsRange; import org.elasticsearch.common.transport.PortsRange;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
@ -48,11 +51,14 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT;
public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport { public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport {
public final HttpHandlingSettings handlingSettings;
protected final NetworkService networkService; protected final NetworkService networkService;
protected final BigArrays bigArrays;
protected final ThreadPool threadPool; protected final ThreadPool threadPool;
protected final Dispatcher dispatcher; protected final Dispatcher dispatcher;
private final NamedXContentRegistry xContentRegistry;
protected final String[] bindHosts; protected final String[] bindHosts;
protected final String[] publishHosts; protected final String[] publishHosts;
@ -61,11 +67,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
protected volatile BoundTransportAddress boundAddress; protected volatile BoundTransportAddress boundAddress;
protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) { protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool,
NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) {
super(settings); super(settings);
this.networkService = networkService; this.networkService = networkService;
this.bigArrays = bigArrays;
this.threadPool = threadPool; this.threadPool = threadPool;
this.xContentRegistry = xContentRegistry;
this.dispatcher = dispatcher; this.dispatcher = dispatcher;
this.handlingSettings = HttpHandlingSettings.fromSettings(settings);
// we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here
List<String> httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); List<String> httpBindHost = SETTING_HTTP_BIND_HOST.get(settings);
@ -156,17 +166,94 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
return publishPort; return publishPort;
} }
public void dispatchRequest(final RestRequest request, final RestChannel channel) { /**
* This method handles an incoming http request.
*
* @param httpRequest that is incoming
* @param httpChannel that received the http request
*/
public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) {
handleIncomingRequest(httpRequest, httpChannel, null);
}
/**
* This method handles an incoming http request that has encountered an error.
*
* @param httpRequest that is incoming
* @param httpChannel that received the http request
* @param exception that was encountered
*/
public void incomingRequestError(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
handleIncomingRequest(httpRequest, httpChannel, exception);
}
// Visible for testing
void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) {
final ThreadContext threadContext = threadPool.getThreadContext(); final ThreadContext threadContext = threadPool.getThreadContext();
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
dispatcher.dispatchRequest(request, channel, threadContext); if (badRequestCause != null) {
dispatcher.dispatchBadRequest(restRequest, channel, threadContext, badRequestCause);
} else {
dispatcher.dispatchRequest(restRequest, channel, threadContext);
}
} }
} }
public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
final ThreadContext threadContext = threadPool.getThreadContext(); Exception badRequestCause = exception;
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
dispatcher.dispatchBadRequest(request, channel, threadContext, cause); /*
* We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there
* are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we
* attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header,
* or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the
* underlying exception that caused us to treat the request as bad.
*/
final RestRequest restRequest;
{
RestRequest innerRestRequest;
try {
innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel);
} catch (final RestRequest.ContentTypeHeaderException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause);
} catch (final RestRequest.BadParameterException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
}
restRequest = innerRestRequest;
}
/*
* We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid
* parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an
* IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these
* parameter values.
*/
final RestChannel channel;
{
RestChannel innerChannel;
ThreadContext threadContext = threadPool.getThreadContext();
try {
innerChannel = new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext);
} catch (final IllegalArgumentException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
innerChannel = new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext);
}
channel = innerChannel;
}
dispatchRequest(restRequest, channel, badRequestCause);
}
private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) {
HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type");
try {
return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel);
} catch (final RestRequest.BadParameterException e) {
badRequestCause.addSuppressed(e);
return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel);
} }
} }
} }

View File

@ -0,0 +1,172 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.AbstractRestChannel;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* The default rest channel for incoming requests. This class implements the basic logic for sending a rest
* response. It will set necessary headers nad ensure that bytes are released after the response is sent.
*/
public class DefaultRestChannel extends AbstractRestChannel implements RestChannel {
static final String CLOSE = "close";
static final String CONNECTION = "connection";
static final String KEEP_ALIVE = "keep-alive";
static final String CONTENT_TYPE = "content-type";
static final String CONTENT_LENGTH = "content-length";
static final String SET_COOKIE = "set-cookie";
static final String X_OPAQUE_ID = "X-Opaque-Id";
private final HttpRequest httpRequest;
private final BigArrays bigArrays;
private final HttpHandlingSettings settings;
private final ThreadContext threadContext;
private final HttpChannel httpChannel;
DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays,
HttpHandlingSettings settings, ThreadContext threadContext) {
super(request, settings.getDetailedErrorsEnabled());
this.httpChannel = httpChannel;
this.httpRequest = httpRequest;
this.bigArrays = bigArrays;
this.settings = settings;
this.threadContext = threadContext;
}
@Override
protected BytesStreamOutput newBytesOutput() {
return new ReleasableBytesStreamOutput(bigArrays);
}
@Override
public void sendResponse(RestResponse restResponse) {
HttpResponse httpResponse;
if (RestRequest.Method.HEAD == request.method()) {
httpResponse = httpRequest.createResponse(restResponse.status(), BytesArray.EMPTY);
} else {
httpResponse = httpRequest.createResponse(restResponse.status(), restResponse.content());
}
// TODO: Ideally we should move the setting of Cors headers into :server
// NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
String opaque = request.header(X_OPAQUE_ID);
if (opaque != null) {
setHeaderField(httpResponse, X_OPAQUE_ID, opaque);
}
// Add all custom headers
addCustomHeaders(httpResponse, restResponse.getHeaders());
addCustomHeaders(httpResponse, threadContext.getResponseHeaders());
ArrayList<Releasable> toClose = new ArrayList<>(3);
boolean success = false;
try {
// If our response doesn't specify a content-type header, set one
setHeaderField(httpResponse, CONTENT_TYPE, restResponse.contentType(), false);
// If our response has no content-length, calculate and set one
setHeaderField(httpResponse, CONTENT_LENGTH, String.valueOf(restResponse.content().length()), false);
addCookies(httpResponse);
BytesReference content = restResponse.content();
if (content instanceof Releasable) {
toClose.add((Releasable) content);
}
BytesStreamOutput bytesStreamOutput = bytesOutputOrNull();
if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) {
toClose.add((Releasable) bytesStreamOutput);
}
if (isCloseConnection()) {
toClose.add(httpChannel);
}
ActionListener<Void> listener = ActionListener.wrap(() -> Releasables.close(toClose));
httpChannel.sendResponse(httpResponse, listener);
success = true;
} finally {
if (success == false) {
Releasables.close(toClose);
}
}
}
private void setHeaderField(HttpResponse response, String headerField, String value) {
setHeaderField(response, headerField, value, true);
}
private void setHeaderField(HttpResponse response, String headerField, String value, boolean override) {
if (override || !response.containsHeader(headerField)) {
response.addHeader(headerField, value);
}
}
private void addCustomHeaders(HttpResponse response, Map<String, List<String>> customHeaders) {
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
setHeaderField(response, headerEntry.getKey(), headerValue);
}
}
}
}
private void addCookies(HttpResponse response) {
if (settings.isResetCookies()) {
List<String> cookies = request.getHttpRequest().strictCookies();
if (cookies.isEmpty() == false) {
for (String cookie : cookies) {
response.addHeader(SET_COOKIE, cookie);
}
}
}
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
final boolean http10 = isHttp10();
return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION)));
}
// Determine if the request protocol version is HTTP 1.0
private boolean isHttp10() {
return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
}
}

View File

@ -0,0 +1,58 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.lease.Releasable;
import java.net.InetSocketAddress;
public interface HttpChannel extends Releasable {
/**
* Sends a http response to the channel. The listener will be executed once the send process has been
* completed.
*
* @param response to send to channel
* @param listener to execute upon send completion
*/
void sendResponse(HttpResponse response, ActionListener<Void> listener);
/**
* Returns the local address for this channel.
*
* @return the local address of this channel.
*/
InetSocketAddress getLocalAddress();
/**
* Returns the remote address for this channel. Can be null if channel does not have a remote address.
*
* @return the remote address of this channel.
*/
InetSocketAddress getRemoteAddress();
/**
* Closes the channel. This might be an asynchronous process. There is no guarantee that the channel
* will be closed when this method returns.
*/
void close();
}

View File

@ -18,20 +18,17 @@
*/ */
package org.elasticsearch.http; package org.elasticsearch.http;
public class HttpPipelinedMessage implements Comparable<HttpPipelinedMessage> { public interface HttpPipelinedMessage extends Comparable<HttpPipelinedMessage> {
private final int sequence; /**
* Get the sequence number for this message.
public HttpPipelinedMessage(int sequence) { *
this.sequence = sequence; * @return the sequence number
} */
int getSequence();
public int getSequence() {
return sequence;
}
@Override @Override
public int compareTo(HttpPipelinedMessage o) { default int compareTo(HttpPipelinedMessage o) {
return Integer.compare(sequence, o.sequence); return Integer.compare(getSequence(), o.getSequence());
} }
} }

View File

@ -18,15 +18,21 @@
*/ */
package org.elasticsearch.http; package org.elasticsearch.http;
public class HttpPipelinedRequest<R> extends HttpPipelinedMessage { public class HttpPipelinedRequest<R> implements HttpPipelinedMessage {
private final R request; private final R request;
private final int sequence;
HttpPipelinedRequest(int sequence, R request) { HttpPipelinedRequest(int sequence, R request) {
super(sequence); this.sequence = sequence;
this.request = request; this.request = request;
} }
@Override
public int getSequence() {
return sequence;
}
public R getRequest() { public R getRequest() {
return request; return request;
} }

View File

@ -0,0 +1,65 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.List;
import java.util.Map;
/**
* A basic http request abstraction. Http modules needs to implement this interface to integrate with the
* server package's rest handling.
*/
public interface HttpRequest {
enum HttpVersion {
HTTP_1_0,
HTTP_1_1
}
RestRequest.Method method();
/**
* The uri of the rest request, with the query string.
*/
String uri();
BytesReference content();
/**
* Get all of the headers and values associated with the headers. Modifications of this map are not supported.
*/
Map<String, List<String>> getHeaders();
List<String> strictCookies();
HttpVersion protocolVersion();
HttpRequest removeHeader(String header);
/**
* Create an http response from this request and the supplied status and content.
*/
HttpResponse createResponse(RestStatus status, BytesReference content);
}

View File

@ -0,0 +1,32 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
/**
* A basic http response abstraction. Http modules must implement this interface as the server package rest
* handling needs to set http headers for a response.
*/
public interface HttpResponse {
void addHeader(String name, String value);
boolean containsHeader(String name);
}

View File

@ -40,7 +40,7 @@ public abstract class AbstractRestChannel implements RestChannel {
private static final Predicate<String> EXCLUDE_FILTER = INCLUDE_FILTER.negate(); private static final Predicate<String> EXCLUDE_FILTER = INCLUDE_FILTER.negate();
protected final RestRequest request; protected final RestRequest request;
protected final boolean detailedErrorsEnabled; private final boolean detailedErrorsEnabled;
private final String format; private final String format;
private final String filterPath; private final String filterPath;
private final boolean pretty; private final boolean pretty;

View File

@ -272,8 +272,9 @@ public class RestController extends AbstractComponent implements HttpServerTrans
*/ */
private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) { private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) {
if (restRequest.getXContentType() == null) { if (restRequest.getXContentType() == null) {
if (restHandler.supportsContentStream() && restRequest.header("Content-Type") != null) { String contentTypeHeader = restRequest.header("Content-Type");
final String lowercaseMediaType = restRequest.header("Content-Type").toLowerCase(Locale.ROOT); if (restHandler.supportsContentStream() && contentTypeHeader != null) {
final String lowercaseMediaType = contentTypeHeader.toLowerCase(Locale.ROOT);
// we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/ // we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/
if (lowercaseMediaType.equals("application/x-ndjson")) { if (lowercaseMediaType.equals("application/x-ndjson")) {
restRequest.setXContentType(XContentType.JSON); restRequest.setXContentType(XContentType.JSON);

View File

@ -35,10 +35,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.net.SocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
@ -51,7 +52,7 @@ import java.util.stream.Collectors;
import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue;
import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; import static org.elasticsearch.common.unit.TimeValue.parseTimeValue;
public abstract class RestRequest implements ToXContent.Params { public class RestRequest implements ToXContent.Params {
// tchar pattern as defined by RFC7230 section 3.2.6 // tchar pattern as defined by RFC7230 section 3.2.6
private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+"); private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+");
@ -62,18 +63,47 @@ public abstract class RestRequest implements ToXContent.Params {
private final String rawPath; private final String rawPath;
private final Set<String> consumedParams = new HashSet<>(); private final Set<String> consumedParams = new HashSet<>();
private final SetOnce<XContentType> xContentType = new SetOnce<>(); private final SetOnce<XContentType> xContentType = new SetOnce<>();
private final HttpRequest httpRequest;
private final HttpChannel httpChannel;
protected RestRequest(NamedXContentRegistry xContentRegistry, Map<String, String> params, String path,
Map<String, List<String>> headers, HttpRequest httpRequest, HttpChannel httpChannel) {
final XContentType xContentType;
try {
xContentType = parseContentType(headers.get("Content-Type"));
} catch (final IllegalArgumentException e) {
throw new ContentTypeHeaderException(e);
}
if (xContentType != null) {
this.xContentType.set(xContentType);
}
this.xContentRegistry = xContentRegistry;
this.httpRequest = httpRequest;
this.httpChannel = httpChannel;
this.params = params;
this.rawPath = path;
this.headers = Collections.unmodifiableMap(headers);
}
protected RestRequest(RestRequest restRequest) {
this(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(),
restRequest.getHttpRequest(), restRequest.getHttpChannel());
}
/** /**
* Creates a new REST request. * Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be
* decoded
* *
* @param xContentRegistry the content registry * @param xContentRegistry the content registry
* @param uri the raw URI that will be parsed into the path and the parameters * @param httpRequest the http request
* @param headers a map of the header; this map should implement a case-insensitive lookup * @param httpChannel the http channel
* @throws BadParameterException if the parameters can not be decoded * @throws BadParameterException if the parameters can not be decoded
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/ */
public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map<String, List<String>> headers) { public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, HttpChannel httpChannel) {
this(xContentRegistry, params(uri), path(uri), headers); Map<String, String> params = params(httpRequest.uri());
String path = path(httpRequest.uri());
return new RestRequest(xContentRegistry, params, path, httpRequest.getHeaders(), httpRequest, httpChannel);
} }
private static Map<String, String> params(final String uri) { private static Map<String, String> params(final String uri) {
@ -99,46 +129,34 @@ public abstract class RestRequest implements ToXContent.Params {
} }
/** /**
* Creates a new REST request. In contrast to * Creates a new REST request. The path is not decoded so this constructor will not throw a
* {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw * {@link BadParameterException}.
* a {@link BadParameterException}.
* *
* @param xContentRegistry the content registry * @param xContentRegistry the content registry
* @param params the request parameters * @param httpRequest the http request
* @param path the raw path (which is not parsed) * @param httpChannel the http channel
* @param headers a map of the header; this map should implement a case-insensitive lookup
* @throws ContentTypeHeaderException if the Content-Type header can not be parsed * @throws ContentTypeHeaderException if the Content-Type header can not be parsed
*/ */
public RestRequest( public static RestRequest requestWithoutParameters(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest,
final NamedXContentRegistry xContentRegistry, HttpChannel httpChannel) {
final Map<String, String> params, Map<String, String> params = Collections.emptyMap();
final String path, return new RestRequest(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel);
final Map<String, List<String>> headers) {
final XContentType xContentType;
try {
xContentType = parseContentType(headers.get("Content-Type"));
} catch (final IllegalArgumentException e) {
throw new ContentTypeHeaderException(e);
}
if (xContentType != null) {
this.xContentType.set(xContentType);
}
this.xContentRegistry = xContentRegistry;
this.params = params;
this.rawPath = path;
this.headers = Collections.unmodifiableMap(headers);
} }
public enum Method { public enum Method {
GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT
} }
public abstract Method method(); public Method method() {
return httpRequest.method();
}
/** /**
* The uri of the rest request, with the query string. * The uri of the rest request, with the query string.
*/ */
public abstract String uri(); public String uri() {
return httpRequest.uri();
}
/** /**
* The non decoded, raw path provided. * The non decoded, raw path provided.
@ -154,9 +172,13 @@ public abstract class RestRequest implements ToXContent.Params {
return RestUtils.decodeComponent(rawPath()); return RestUtils.decodeComponent(rawPath());
} }
public abstract boolean hasContent(); public boolean hasContent() {
return content().length() > 0;
}
public abstract BytesReference content(); public BytesReference content() {
return httpRequest.content();
}
/** /**
* @return content of the request body or throw an exception if the body or content type is missing * @return content of the request body or throw an exception if the body or content type is missing
@ -216,14 +238,12 @@ public abstract class RestRequest implements ToXContent.Params {
this.xContentType.set(xContentType); this.xContentType.set(xContentType);
} }
@Nullable public HttpChannel getHttpChannel() {
public SocketAddress getRemoteAddress() { return httpChannel;
return null;
} }
@Nullable public HttpRequest getHttpRequest() {
public SocketAddress getLocalAddress() { return httpRequest;
return null;
} }
public final boolean hasParam(String key) { public final boolean hasParam(String key) {

View File

@ -20,10 +20,10 @@
package org.elasticsearch.rest; package org.elasticsearch.rest;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -31,8 +31,7 @@ import java.util.Set;
public abstract class RestResponse { public abstract class RestResponse {
protected Map<String, List<String>> customHeaders; private Map<String, List<String>> customHeaders;
/** /**
* The response content type. * The response content type.
@ -81,10 +80,13 @@ public abstract class RestResponse {
} }
/** /**
* Returns custom headers that have been added, or null if none have been set. * Returns custom headers that have been added. This method should not be used to mutate headers.
*/ */
@Nullable
public Map<String, List<String>> getHeaders() { public Map<String, List<String>> getHeaders() {
if (customHeaders == null) {
return Collections.emptyMap();
} else {
return customHeaders; return customHeaders;
} }
} }
}

View File

@ -22,6 +22,7 @@ import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
@ -39,6 +40,7 @@ import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.BaseFuture; import org.elasticsearch.common.util.concurrent.BaseFuture;
import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.MockLogAppender;
import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.junit.annotations.TestLogging;
@ -65,6 +67,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import static java.util.Collections.emptyMap; import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet; import static java.util.Collections.emptySet;
@ -680,6 +683,132 @@ public class MasterServiceTests extends ESTestCase {
mockAppender.assertAllExpectationsMatched(); mockAppender.assertAllExpectationsMatched();
} }
public void testAcking() throws InterruptedException {
final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
final DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
final DiscoveryNode node3 = new DiscoveryNode("node3", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT);
TimedMasterService timedMasterService = new TimedMasterService(Settings.builder().put("cluster.name",
MasterServiceTests.class.getSimpleName()).build(), threadPool);
ClusterState initialClusterState = ClusterState.builder(new ClusterName(MasterServiceTests.class.getSimpleName()))
.nodes(DiscoveryNodes.builder()
.add(node1)
.add(node2)
.add(node3)
.localNodeId(node1.getId())
.masterNodeId(node1.getId()))
.blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK).build();
final AtomicReference<BiConsumer<ClusterChangedEvent, Discovery.AckListener>> publisherRef = new AtomicReference<>();
timedMasterService.setClusterStatePublisher((cce, l) -> publisherRef.get().accept(cce, l));
timedMasterService.setClusterStateSupplier(() -> initialClusterState);
timedMasterService.start();
// check that we don't time out before even committing the cluster state
{
final CountDownLatch latch = new CountDownLatch(1);
publisherRef.set((clusterChangedEvent, ackListener) -> {
throw new Discovery.FailedToCommitClusterStateException("mock exception");
});
timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask<Void>(null, null) {
@Override
public ClusterState execute(ClusterState currentState) {
return ClusterState.builder(currentState).build();
}
@Override
public TimeValue ackTimeout() {
return TimeValue.ZERO;
}
@Override
public TimeValue timeout() {
return null;
}
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
fail();
}
@Override
protected Void newResponse(boolean acknowledged) {
fail();
return null;
}
@Override
public void onFailure(String source, Exception e) {
latch.countDown();
}
@Override
public void onAckTimeout() {
fail();
}
});
latch.await();
}
// check that we timeout if commit took too long
{
final CountDownLatch latch = new CountDownLatch(2);
final TimeValue ackTimeout = TimeValue.timeValueMillis(randomInt(100));
publisherRef.set((clusterChangedEvent, ackListener) -> {
ackListener.onCommit(TimeValue.timeValueMillis(ackTimeout.millis() + randomInt(100)));
ackListener.onNodeAck(node1, null);
ackListener.onNodeAck(node2, null);
ackListener.onNodeAck(node3, null);
});
timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask<Void>(null, null) {
@Override
public ClusterState execute(ClusterState currentState) {
return ClusterState.builder(currentState).build();
}
@Override
public TimeValue ackTimeout() {
return ackTimeout;
}
@Override
public TimeValue timeout() {
return null;
}
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
latch.countDown();
}
@Override
protected Void newResponse(boolean acknowledged) {
fail();
return null;
}
@Override
public void onFailure(String source, Exception e) {
fail();
}
@Override
public void onAckTimeout() {
latch.countDown();
}
});
latch.await();
}
timedMasterService.close();
}
static class TimedMasterService extends MasterService { static class TimedMasterService extends MasterService {
public volatile Long currentTimeOverride = null; public volatile Long currentTimeOverride = null;

View File

@ -42,6 +42,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.Discovery;
import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.discovery.DiscoverySettings;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
@ -815,9 +816,16 @@ public class PublishClusterStateActionTests extends ESTestCase {
public static class AssertingAckListener implements Discovery.AckListener { public static class AssertingAckListener implements Discovery.AckListener {
private final List<Tuple<DiscoveryNode, Throwable>> errors = new CopyOnWriteArrayList<>(); private final List<Tuple<DiscoveryNode, Throwable>> errors = new CopyOnWriteArrayList<>();
private final CountDownLatch countDown; private final CountDownLatch countDown;
private final CountDownLatch commitCountDown;
public AssertingAckListener(int nodeCount) { public AssertingAckListener(int nodeCount) {
countDown = new CountDownLatch(nodeCount); countDown = new CountDownLatch(nodeCount);
commitCountDown = new CountDownLatch(1);
}
@Override
public void onCommit(TimeValue commitTime) {
commitCountDown.countDown();
} }
@Override @Override
@ -830,6 +838,7 @@ public class PublishClusterStateActionTests extends ESTestCase {
public void await(long timeout, TimeUnit unit) throws InterruptedException { public void await(long timeout, TimeUnit unit) throws InterruptedException {
assertThat(awaitErrors(timeout, unit), emptyIterable()); assertThat(awaitErrors(timeout, unit), emptyIterable());
assertTrue(commitCountDown.await(timeout, unit));
} }
public List<Tuple<DiscoveryNode, Throwable>> awaitErrors(long timeout, TimeUnit unit) throws InterruptedException { public List<Tuple<DiscoveryNode, Throwable>> awaitErrors(long timeout, TimeUnit unit) throws InterruptedException {

View File

@ -19,13 +19,27 @@
package org.elasticsearch.http; package org.elasticsearch.http;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.network.NetworkUtils; import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import static java.net.InetAddress.getByName; import static java.net.InetAddress.getByName;
@ -36,6 +50,27 @@ import static org.hamcrest.Matchers.equalTo;
public class AbstractHttpServerTransportTests extends ESTestCase { public class AbstractHttpServerTransportTests extends ESTestCase {
private NetworkService networkService;
private ThreadPool threadPool;
private MockBigArrays bigArrays;
@Before
public void setup() throws Exception {
networkService = new NetworkService(Collections.emptyList());
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() throws Exception {
if (threadPool != null) {
threadPool.shutdownNow();
}
threadPool = null;
networkService = null;
bigArrays = null;
}
public void testHttpPublishPort() throws Exception { public void testHttpPublishPort() throws Exception {
int boundPort = randomIntBetween(9000, 9100); int boundPort = randomIntBetween(9000, 9100);
int otherBoundPort = randomIntBetween(9200, 9300); int otherBoundPort = randomIntBetween(9200, 9300);
@ -71,6 +106,64 @@ public class AbstractHttpServerTransportTests extends ESTestCase {
} }
} }
public void testDispatchDoesNotModifyThreadContext() {
final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() {
@Override
public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) {
threadContext.putHeader("foo", "bar");
threadContext.putTransient("bar", "baz");
}
@Override
public void dispatchBadRequest(final RestRequest request,
final RestChannel channel,
final ThreadContext threadContext,
final Throwable cause) {
threadContext.putHeader("foo_bad", "bar");
threadContext.putTransient("bar_bad", "baz");
}
};
try (AbstractHttpServerTransport transport =
new AbstractHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher) {
@Override
protected TransportAddress bindAddress(InetAddress hostAddress) {
return null;
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
@Override
protected void doClose() throws IOException {
}
@Override
public HttpStats stats() {
return null;
}
}) {
transport.dispatchRequest(null, null, null);
assertNull(threadPool.getThreadContext().getHeader("foo"));
assertNull(threadPool.getThreadContext().getTransient("bar"));
transport.dispatchRequest(null, null, new Exception());
assertNull(threadPool.getThreadContext().getHeader("foo_bad"));
assertNull(threadPool.getThreadContext().getTransient("bar_bad"));
}
}
private TransportAddress address(String host, int port) throws UnknownHostException { private TransportAddress address(String host, int port) throws UnknownHostException {
return new TransportAddress(getByName(host), port); return new TransportAddress(getByName(host), port);
} }

View File

@ -0,0 +1,444 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.MockBigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestResponse;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
public class DefaultRestChannelTests extends ESTestCase {
private ThreadPool threadPool;
private MockBigArrays bigArrays;
private HttpChannel httpChannel;
@Before
public void setup() {
httpChannel = mock(HttpChannel.class);
threadPool = new TestThreadPool("test");
bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
}
@After
public void shutdown() {
if (threadPool != null) {
threadPool.shutdownNow();
}
}
public void testResponse() {
final TestResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(new TestRestResponse().content()));
}
// TODO: Enable these Cors tests when the Cors logic lives in :server
// public void testCorsEnabledWithoutAllowOrigins() {
// // Set up a HTTP transport with only the CORS enabled setting
// Settings settings = Settings.builder()
// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
// }
//
// public void testCorsEnabledWithAllowOrigins() {
// final String originValue = "remote-host";
// // create a http transport with CORS enabled and allow origin configured
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testCorsAllowOriginWithSameHost() {
// String originValue = "remote-host";
// String host = "remote-host";
// // create a http transport with CORS enabled
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, host);
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = "http://" + originValue;
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue + ":5555";
// host = host + ":5555";
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue.replace("http", "https");
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testThatStringLiteralWorksOnMatch() {
// final String originValue = "remote-host";
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
// }
//
// public void testThatAnyOriginWorks() {
// final String originValue = NioCorsHandler.ANY_ORIGIN;
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
// }
public void testHeadersSet() {
Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
TestRestResponse resp = new TestRestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
resp.addHeader(customHeader, customHeaderValue);
channel.sendResponse(resp);
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse httpResponse = responseCaptor.getValue();
Map<String, List<String>> headers = httpResponse.headers;
assertNull(headers.get("non-existent-header"));
assertEquals(customHeaderValue, headers.get(customHeader).get(0));
assertEquals("abc", headers.get(DefaultRestChannel.X_OPAQUE_ID).get(0));
assertEquals(Integer.toString(resp.content().length()), headers.get(DefaultRestChannel.CONTENT_LENGTH).get(0));
assertEquals(resp.contentType(), headers.get(DefaultRestChannel.CONTENT_TYPE).get(0));
}
public void testCookiesSet() {
Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse nioResponse = responseCaptor.getValue();
Map<String, List<String>> headers = nioResponse.headers;
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie"));
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2"));
}
@SuppressWarnings("unchecked")
public void testReleaseInListener() throws IOException {
final Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
// ensure we have reserved bytes
if (randomBoolean()) {
BytesStreamOutput out = channel.bytesOutput();
assertThat(out, instanceOf(ReleasableBytesStreamOutput.class));
} else {
try (XContentBuilder builder = channel.newBuilder()) {
// do something builder
builder.startObject().endObject();
}
}
channel.sendResponse(response);
Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(httpChannel).sendResponse(any(), listenerCaptor.capture());
ActionListener<Void> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new ClosedChannelException());
}
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
}
@SuppressWarnings("unchecked")
public void testConnectionClose() throws Exception {
final Settings settings = Settings.builder().build();
final HttpRequest httpRequest;
final boolean close = randomBoolean();
if (randomBoolean()) {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
if (close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE));
}
} else {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
if (!close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE));
}
}
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
verify(httpChannel).sendResponse(any(), listenerCaptor.capture());
ActionListener<Void> listener = listenerCaptor.getValue();
if (randomBoolean()) {
listener.onResponse(null);
} else {
listener.onFailure(new ClosedChannelException());
}
if (close) {
verify(httpChannel, times(1)).close();
} else {
verify(httpChannel, times(0)).close();
}
}
private TestResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private TestResponse executeRequest(final Settings settings, final String originValue, final String host) {
HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
// TODO: These exist for the Cors tests
// if (originValue != null) {
// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
// }
// httpRequest.headers().add(HttpHeaderNames.HOST, host);
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings,
threadPool.getThreadContext());
channel.sendResponse(new TestRestResponse());
// get the response
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any());
return responseCaptor.getValue();
}
private static class TestRequest implements HttpRequest {
private final HttpVersion version;
private final RestRequest.Method method;
private final String uri;
private HashMap<String, List<String>> headers = new HashMap<>();
private TestRequest(HttpVersion version, RestRequest.Method method, String uri) {
this.version = version;
this.method = method;
this.uri = uri;
}
@Override
public RestRequest.Method method() {
return method;
}
@Override
public String uri() {
return uri;
}
@Override
public BytesReference content() {
return BytesArray.EMPTY;
}
@Override
public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Arrays.asList("cookie", "cookie2");
}
@Override
public HttpVersion protocolVersion() {
return version;
}
@Override
public HttpRequest removeHeader(String header) {
throw new UnsupportedOperationException("Do not support removing header on test request.");
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return new TestResponse(status, content);
}
}
private static class TestResponse implements HttpResponse {
private final RestStatus status;
private final BytesReference content;
private final Map<String, List<String>> headers = new HashMap<>();
TestResponse(RestStatus status, BytesReference content) {
this.status = status;
this.content = content;
}
public String contentType() {
return "text";
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return status;
}
@Override
public void addHeader(String name, String value) {
if (headers.containsKey(name) == false) {
ArrayList<String> values = new ArrayList<>();
values.add(value);
headers.put(name, values);
} else {
headers.get(name).add(value);
}
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
}
private static class TestRestResponse extends RestResponse {
private final BytesReference content;
TestRestResponse() {
content = new BytesArray("content".getBytes(StandardCharsets.UTF_8));
}
public String contentType() {
return "text";
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -29,7 +29,6 @@ import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
@ -165,28 +164,7 @@ public class BytesRestResponseTests extends ESTestCase {
public void testResponseWhenPathContainsEncodingError() throws IOException { public void testResponseWhenPathContainsEncodingError() throws IOException {
final String path = "%a"; final String path = "%a";
final RestRequest request = final RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath(path).build();
new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) {
@Override
public Method method() {
return null;
}
@Override
public String uri() {
return null;
}
@Override
public boolean hasContent() {
return false;
}
@Override
public BytesReference content() {
return null;
}
};
final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath())); final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath()));
final RestChannel channel = new DetailedExceptionRestChannel(request); final RestChannel channel = new DetailedExceptionRestChannel(request);
// if we try to decode the path, this will throw an IllegalArgumentException again // if we try to decode the path, this will throw an IllegalArgumentException again

View File

@ -240,7 +240,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestAddsAndFreesBytesOnSuccess() { public void testDispatchRequestAddsAndFreesBytesOnSuccess() {
int contentLength = BREAKER_LIMIT.bytesAsInt(); int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength); String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); RestRequest request = testRestRequest("/", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -252,7 +252,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestAddsAndFreesBytesOnError() { public void testDispatchRequestAddsAndFreesBytesOnError() {
int contentLength = BREAKER_LIMIT.bytesAsInt(); int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength); String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); RestRequest request = testRestRequest("/error", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -265,7 +265,7 @@ public class RestControllerTests extends ESTestCase {
int contentLength = BREAKER_LIMIT.bytesAsInt(); int contentLength = BREAKER_LIMIT.bytesAsInt();
String content = randomAlphaOfLength(contentLength); String content = randomAlphaOfLength(contentLength);
// we will produce an error in the rest handler and one more when sending the error response // we will produce an error in the rest handler and one more when sending the error response
TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); RestRequest request = testRestRequest("/error", content, XContentType.JSON);
ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true); ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -277,7 +277,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequestLimitsBytes() { public void testDispatchRequestLimitsBytes() {
int contentLength = BREAKER_LIMIT.bytesAsInt() + 1; int contentLength = BREAKER_LIMIT.bytesAsInt() + 1;
String content = randomAlphaOfLength(contentLength); String content = randomAlphaOfLength(contentLength);
TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); RestRequest request = testRestRequest("/", content, XContentType.JSON);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE);
restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY));
@ -288,7 +288,7 @@ public class RestControllerTests extends ESTestCase {
public void testDispatchRequiresContentTypeForRequestsWithContent() { public void testDispatchRequiresContentTypeForRequestsWithContent() {
String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt()); String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt());
TestRestRequest request = new TestRestRequest("/", content, null); RestRequest request = testRestRequest("/", content, null);
AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE);
restController = new RestController( restController = new RestController(
Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(), Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(),
@ -547,35 +547,11 @@ public class RestControllerTests extends ESTestCase {
} }
} }
private static final class TestRestRequest extends RestRequest { private static RestRequest testRestRequest(String path, String content, XContentType xContentType) {
FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY);
private final BytesReference content; builder.withPath(path);
builder.withContent(new BytesArray(content), xContentType);
private TestRestRequest(String path, String content, XContentType xContentType) { return builder.build();
super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, xContentType == null ?
Collections.emptyMap() : Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType())));
this.content = new BytesArray(content);
}
@Override
public Method method() {
return Method.GET;
}
@Override
public String uri() {
return null;
}
@Override
public boolean hasContent() {
return true;
}
@Override
public BytesReference content() {
return content;
}
} }
} }

View File

@ -27,6 +27,7 @@ import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
@ -44,66 +45,66 @@ import static org.hamcrest.Matchers.instanceOf;
public class RestRequestTests extends ESTestCase { public class RestRequestTests extends ESTestCase {
public void testContentParser() throws IOException { public void testContentParser() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () -> Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentParser()); contentRestRequest("", emptyMap()).contentParser());
assertEquals("request body is required", e.getMessage()); assertEquals("request body is required", e.getMessage());
e = expectThrows(ElasticsearchParseException.class, () -> e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", singletonMap("source", "{}")).contentParser()); contentRestRequest("", singletonMap("source", "{}")).contentParser());
assertEquals("request body is required", e.getMessage()); assertEquals("request body is required", e.getMessage());
assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentParser().map()); assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentParser().map());
e = expectThrows(ElasticsearchParseException.class, () -> e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap(), emptyMap()).contentParser()); contentRestRequest("", emptyMap(), emptyMap()).contentParser());
assertEquals("request body is required", e.getMessage()); assertEquals("request body is required", e.getMessage());
} }
public void testApplyContentParser() throws IOException { public void testApplyContentParser() throws IOException {
new ContentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); contentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called"));
new ContentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); contentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called"));
AtomicReference<Object> source = new AtomicReference<>(); AtomicReference<Object> source = new AtomicReference<>();
new ContentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); contentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map()));
assertEquals(emptyMap(), source.get()); assertEquals(emptyMap(), source.get());
} }
public void testContentOrSourceParam() throws IOException { public void testContentOrSourceParam() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () -> Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentOrSourceParam()); contentRestRequest("", emptyMap()).contentOrSourceParam());
assertEquals("request body or source parameter is required", e.getMessage()); assertEquals("request body or source parameter is required", e.getMessage());
assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2());
assertEquals(new BytesArray("stuff"), assertEquals(new BytesArray("stuff"),
new ContentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder() contentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2()); .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2());
assertEquals(new BytesArray("{\"foo\": \"stuff\"}"), assertEquals(new BytesArray("{\"foo\": \"stuff\"}"),
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder() contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap())
.contentOrSourceParam().v2()); .contentOrSourceParam().v2());
e = expectThrows(IllegalStateException.class, () -> e = expectThrows(IllegalStateException.class, () ->
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder() contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").immutableMap()).contentOrSourceParam()); .put("source", "stuff2").immutableMap()).contentOrSourceParam());
assertEquals("source and source_content_type parameters are required", e.getMessage()); assertEquals("source and source_content_type parameters are required", e.getMessage());
} }
public void testHasContentOrSourceParam() throws IOException { public void testHasContentOrSourceParam() throws IOException {
assertEquals(false, new ContentRestRequest("", emptyMap()).hasContentOrSourceParam()); assertEquals(false, contentRestRequest("", emptyMap()).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); assertEquals(true, contentRestRequest("stuff", emptyMap()).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); assertEquals(true, contentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam());
assertEquals(true, new ContentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); assertEquals(true, contentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam());
} }
public void testContentOrSourceParamParser() throws IOException { public void testContentOrSourceParamParser() throws IOException {
Exception e = expectThrows(ElasticsearchParseException.class, () -> Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).contentOrSourceParamParser()); contentRestRequest("", emptyMap()).contentOrSourceParamParser());
assertEquals("request body or source parameter is required", e.getMessage()); assertEquals("request body or source parameter is required", e.getMessage());
assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map());
assertEquals(emptyMap(), new ContentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); assertEquals(emptyMap(), contentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map());
assertEquals(emptyMap(), new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder() assertEquals(emptyMap(), contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map()); .put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map());
} }
public void testWithContentOrSourceParamParserOrNull() throws IOException { public void testWithContentOrSourceParamParserOrNull() throws IOException {
new ContentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); contentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser));
new ContentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); contentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map()));
new ContentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> contentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser ->
assertEquals(emptyMap(), parser.map())); assertEquals(emptyMap(), parser.map()));
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder().put("source_content_type", "application/json") contentRestRequest("", MapBuilder.<String, String>newMapBuilder().put("source_content_type", "application/json")
.put("source", "{}").immutableMap()) .put("source", "{}").immutableMap())
.withContentOrSourceParamParserOrNull(parser -> .withContentOrSourceParamParserOrNull(parser ->
assertEquals(emptyMap(), parser.map())); assertEquals(emptyMap(), parser.map()));
@ -113,18 +114,18 @@ public class RestRequestTests extends ESTestCase {
for (XContentType xContentType : XContentType.values()) { for (XContentType xContentType : XContentType.values()) {
Map<String, List<String>> map = new HashMap<>(); Map<String, List<String>> map = new HashMap<>();
map.put("Content-Type", Collections.singletonList(xContentType.mediaType())); map.put("Content-Type", Collections.singletonList(xContentType.mediaType()));
ContentRestRequest restRequest = new ContentRestRequest("", Collections.emptyMap(), map); RestRequest restRequest = contentRestRequest("", Collections.emptyMap(), map);
assertEquals(xContentType, restRequest.getXContentType()); assertEquals(xContentType, restRequest.getXContentType());
map = new HashMap<>(); map = new HashMap<>();
map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters())); map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters()));
restRequest = new ContentRestRequest("", Collections.emptyMap(), map); restRequest = contentRestRequest("", Collections.emptyMap(), map);
assertEquals(xContentType, restRequest.getXContentType()); assertEquals(xContentType, restRequest.getXContentType());
} }
} }
public void testPlainTextSupport() { public void testPlainTextSupport() {
ContentRestRequest restRequest = new ContentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), RestRequest restRequest = contentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(),
Collections.singletonMap("Content-Type", Collections.singletonMap("Content-Type",
Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8")))); Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8"))));
assertNull(restRequest.getXContentType()); assertNull(restRequest.getXContentType());
@ -136,7 +137,7 @@ public class RestRequestTests extends ESTestCase {
RestRequest.ContentTypeHeaderException.class, RestRequest.ContentTypeHeaderException.class,
() -> { () -> {
final Map<String, List<String>> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type)); final Map<String, List<String>> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type));
new ContentRestRequest("", Collections.emptyMap(), headers); contentRestRequest("", Collections.emptyMap(), headers);
}); });
assertNotNull(e.getCause()); assertNotNull(e.getCause());
assertThat(e.getCause(), instanceOf(IllegalArgumentException.class)); assertThat(e.getCause(), instanceOf(IllegalArgumentException.class));
@ -144,7 +145,7 @@ public class RestRequestTests extends ESTestCase {
} }
public void testNoContentTypeHeader() { public void testNoContentTypeHeader() {
ContentRestRequest contentRestRequest = new ContentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); RestRequest contentRestRequest = contentRestRequest("", Collections.emptyMap(), Collections.emptyMap());
assertNull(contentRestRequest.getXContentType()); assertNull(contentRestRequest.getXContentType());
} }
@ -152,7 +153,7 @@ public class RestRequestTests extends ESTestCase {
List<String> headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10))); List<String> headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10)));
final RestRequest.ContentTypeHeaderException e = expectThrows( final RestRequest.ContentTypeHeaderException e = expectThrows(
RestRequest.ContentTypeHeaderException.class, RestRequest.ContentTypeHeaderException.class,
() -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); () -> contentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers)));
assertNotNull(e.getCause()); assertNotNull(e.getCause());
assertThat(e.getCause(), instanceOf((IllegalArgumentException.class))); assertThat(e.getCause(), instanceOf((IllegalArgumentException.class)));
assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided")); assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided"));
@ -160,52 +161,64 @@ public class RestRequestTests extends ESTestCase {
public void testRequiredContent() { public void testRequiredContent() {
Exception e = expectThrows(ElasticsearchParseException.class, () -> Exception e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", emptyMap()).requiredContent()); contentRestRequest("", emptyMap()).requiredContent());
assertEquals("request body is required", e.getMessage()); assertEquals("request body is required", e.getMessage());
assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).requiredContent()); assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).requiredContent());
assertEquals(new BytesArray("stuff"), assertEquals(new BytesArray("stuff"),
new ContentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder() contentRestRequest("stuff", MapBuilder.<String, String>newMapBuilder()
.put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent()); .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent());
e = expectThrows(ElasticsearchParseException.class, () -> e = expectThrows(ElasticsearchParseException.class, () ->
new ContentRestRequest("", MapBuilder.<String, String>newMapBuilder() contentRestRequest("", MapBuilder.<String, String>newMapBuilder()
.put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap())
.requiredContent()); .requiredContent());
assertEquals("request body is required", e.getMessage()); assertEquals("request body is required", e.getMessage());
e = expectThrows(IllegalStateException.class, () -> e = expectThrows(IllegalStateException.class, () ->
new ContentRestRequest("test", null, Collections.emptyMap()).requiredContent()); contentRestRequest("test", null, Collections.emptyMap()).requiredContent());
assertEquals("unknown content type", e.getMessage()); assertEquals("unknown content type", e.getMessage());
} }
private static RestRequest contentRestRequest(String content, Map<String, String> params) {
Map<String, List<String>> headers = new HashMap<>();
headers.put("Content-Type", Collections.singletonList("application/json"));
return contentRestRequest(content, params, headers);
}
private static RestRequest contentRestRequest(String content, Map<String, String> params, Map<String, List<String>> headers) {
FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY);
builder.withHeaders(headers);
builder.withContent(new BytesArray(content), null);
builder.withParams(params);
return new ContentRestRequest(builder.build());
}
private static final class ContentRestRequest extends RestRequest { private static final class ContentRestRequest extends RestRequest {
private final BytesArray content;
ContentRestRequest(String content, Map<String, String> params) { private final RestRequest restRequest;
this(content, params, Collections.singletonMap("Content-Type", Collections.singletonList("application/json")));
}
ContentRestRequest(String content, Map<String, String> params, Map<String, List<String>> headers) { private ContentRestRequest(RestRequest restRequest) {
super(NamedXContentRegistry.EMPTY, params, "not used by this test", headers); super(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(),
this.content = new BytesArray(content); restRequest.getHttpRequest(), restRequest.getHttpChannel());
} this.restRequest = restRequest;
@Override
public boolean hasContent() {
return Strings.hasLength(content);
}
@Override
public BytesReference content() {
return content;
}
@Override
public String uri() {
throw new UnsupportedOperationException("Not used by this test");
} }
@Override @Override
public Method method() { public Method method() {
throw new UnsupportedOperationException("Not used by this test"); return restRequest.method();
}
@Override
public String uri() {
return restRequest.uri();
}
@Override
public boolean hasContent() {
return Strings.hasLength(content());
}
@Override
public BytesReference content() {
return restRequest.content();
} }
} }
} }

View File

@ -30,6 +30,7 @@ import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.elasticsearch.transport.RemoteClusterConnectionTests.startTransport; import static org.elasticsearch.transport.RemoteClusterConnectionTests.startTransport;
@ -69,7 +70,6 @@ public class RemoteClusterClientTests extends ESTestCase {
} }
} }
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/29547")
public void testEnsureWeReconnect() throws Exception { public void testEnsureWeReconnect() throws Exception {
Settings remoteSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "foo_bar_cluster").build(); Settings remoteSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "foo_bar_cluster").build();
try (MockTransportService remoteTransport = startTransport("remote_node", Collections.emptyList(), Version.CURRENT, threadPool, try (MockTransportService remoteTransport = startTransport("remote_node", Collections.emptyList(), Version.CURRENT, threadPool,
@ -79,17 +79,35 @@ public class RemoteClusterClientTests extends ESTestCase {
.put(RemoteClusterService.ENABLE_REMOTE_CLUSTERS.getKey(), true) .put(RemoteClusterService.ENABLE_REMOTE_CLUSTERS.getKey(), true)
.put("search.remote.test.seeds", remoteNode.getAddress().getAddress() + ":" + remoteNode.getAddress().getPort()).build(); .put("search.remote.test.seeds", remoteNode.getAddress().getAddress() + ":" + remoteNode.getAddress().getPort()).build();
try (MockTransportService service = MockTransportService.createNewService(localSettings, Version.CURRENT, threadPool, null)) { try (MockTransportService service = MockTransportService.createNewService(localSettings, Version.CURRENT, threadPool, null)) {
Semaphore semaphore = new Semaphore(1);
service.start(); service.start();
service.addConnectionListener(new TransportConnectionListener() {
@Override
public void onNodeDisconnected(DiscoveryNode node) {
if (remoteNode.equals(node)) {
semaphore.release();
}
}
});
// this test is not perfect since we might reconnect concurrently but it will fail most of the time if we don't have
// the right calls in place in the RemoteAwareClient
service.acceptIncomingRequests(); service.acceptIncomingRequests();
for (int i = 0; i < 10; i++) {
semaphore.acquire();
try {
service.disconnectFromNode(remoteNode); service.disconnectFromNode(remoteNode);
semaphore.acquire();
RemoteClusterService remoteClusterService = service.getRemoteClusterService(); RemoteClusterService remoteClusterService = service.getRemoteClusterService();
assertBusy(() -> assertFalse(remoteClusterService.isRemoteNodeConnected("test", remoteNode)));
Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test"); Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test");
ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get(); ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get();
assertNotNull(clusterStateResponse); assertNotNull(clusterStateResponse);
assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value()); assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value());
assertTrue(remoteClusterService.isRemoteNodeConnected("test", remoteNode));
} finally {
semaphore.release();
}
}
} }
} }
} }
} }

View File

@ -19,12 +19,18 @@
package org.elasticsearch.test.rest; package org.elasticsearch.test.rest;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.net.SocketAddress; import java.net.InetSocketAddress;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -32,20 +38,33 @@ import java.util.Map;
public class FakeRestRequest extends RestRequest { public class FakeRestRequest extends RestRequest {
private final BytesReference content;
private final Method method;
private final SocketAddress remoteAddress;
public FakeRestRequest() { public FakeRestRequest() {
this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null); this(NamedXContentRegistry.EMPTY, new FakeHttpRequest(Method.GET, "", BytesArray.EMPTY, new HashMap<>()), new HashMap<>(),
new FakeHttpChannel(null));
} }
private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map<String, List<String>> headers, private FakeRestRequest(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, Map<String, String> params,
Map<String, String> params, BytesReference content, Method method, String path, SocketAddress remoteAddress) { HttpChannel httpChannel) {
super(xContentRegistry, params, path, headers); super(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel);
this.content = content; }
@Override
public boolean hasContent() {
return content() != null;
}
private static class FakeHttpRequest implements HttpRequest {
private final Method method;
private final String uri;
private final BytesReference content;
private final Map<String, List<String>> headers;
private FakeHttpRequest(Method method, String uri, BytesReference content, Map<String, List<String>> headers) {
this.method = method; this.method = method;
this.remoteAddress = remoteAddress; this.uri = uri;
this.content = content;
this.headers = headers;
} }
@Override @Override
@ -55,12 +74,7 @@ public class FakeRestRequest extends RestRequest {
@Override @Override
public String uri() { public String uri() {
return rawPath(); return uri;
}
@Override
public boolean hasContent() {
return content != null;
} }
@Override @Override
@ -69,10 +83,72 @@ public class FakeRestRequest extends RestRequest {
} }
@Override @Override
public SocketAddress getRemoteAddress() { public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Collections.emptyList();
}
@Override
public HttpVersion protocolVersion() {
return HttpVersion.HTTP_1_1;
}
@Override
public HttpRequest removeHeader(String header) {
headers.remove(header);
return this;
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
Map<String, String> headers = new HashMap<>();
return new HttpResponse() {
@Override
public void addHeader(String name, String value) {
headers.put(name, value);
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
};
}
}
private static class FakeHttpChannel implements HttpChannel {
private final InetSocketAddress remoteAddress;
private FakeHttpChannel(InetSocketAddress remoteAddress) {
this.remoteAddress = remoteAddress;
}
@Override
public void sendResponse(HttpResponse response, ActionListener<Void> listener) {
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return remoteAddress; return remoteAddress;
} }
@Override
public void close() {
}
}
public static class Builder { public static class Builder {
private final NamedXContentRegistry xContentRegistry; private final NamedXContentRegistry xContentRegistry;
@ -86,7 +162,7 @@ public class FakeRestRequest extends RestRequest {
private Method method = Method.GET; private Method method = Method.GET;
private SocketAddress address = null; private InetSocketAddress address = null;
public Builder(NamedXContentRegistry xContentRegistry) { public Builder(NamedXContentRegistry xContentRegistry) {
this.xContentRegistry = xContentRegistry; this.xContentRegistry = xContentRegistry;
@ -120,15 +196,14 @@ public class FakeRestRequest extends RestRequest {
return this; return this;
} }
public Builder withRemoteAddress(SocketAddress address) { public Builder withRemoteAddress(InetSocketAddress address) {
this.address = address; this.address = address;
return this; return this;
} }
public FakeRestRequest build() { public FakeRestRequest build() {
return new FakeRestRequest(xContentRegistry, headers, params, content, method, path, address); FakeHttpRequest fakeHttpRequest = new FakeHttpRequest(method, path, content, headers);
return new FakeRestRequest(xContentRegistry, fakeHttpRequest, params, new FakeHttpChannel(address));
} }
} }
} }

View File

@ -5,14 +5,6 @@ import org.elasticsearch.gradle.precommit.LicenseHeadersTask
Project xpackRootProject = project Project xpackRootProject = project
apply plugin: 'nebula.info-scm'
final String licenseCommit
if (version.endsWith('-SNAPSHOT')) {
licenseCommit = xpackRootProject.scminfo.change ?: "master" // leniency for non git builds
} else {
licenseCommit = "v${version}"
}
subprojects { subprojects {
group = 'org.elasticsearch.plugin' group = 'org.elasticsearch.plugin'
ext.xpackRootProject = xpackRootProject ext.xpackRootProject = xpackRootProject
@ -21,7 +13,7 @@ subprojects {
ext.xpackModule = { String moduleName -> xpackProject("plugin:${moduleName}").path } ext.xpackModule = { String moduleName -> xpackProject("plugin:${moduleName}").path }
ext.licenseName = 'Elastic License' ext.licenseName = 'Elastic License'
ext.licenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt" ext.licenseUrl = ext.elasticLicenseUrl
project.ext.licenseFile = rootProject.file('licenses/ELASTIC-LICENSE.txt') project.ext.licenseFile = rootProject.file('licenses/ELASTIC-LICENSE.txt')
project.ext.noticeFile = xpackRootProject.file('NOTICE.txt') project.ext.noticeFile = xpackRootProject.file('NOTICE.txt')

View File

@ -43,9 +43,7 @@ subprojects {
final FileCollection classDirectories = project.files(files).filter { it.exists() } final FileCollection classDirectories = project.files(files).filter { it.exists() }
doFirst { doFirst {
String cp = project.configurations.featureAwarePlugin.asPath args('-cp', project.configurations.featureAwarePlugin.asPath, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck')
cp = cp.replaceAll(":[^:]*/asm-debug-all-5.1.jar:", ":")
args('-cp', cp, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck')
classDirectories.each { args it.getAbsolutePath() } classDirectories.each { args it.getAbsolutePath() }
} }
doLast { doLast {

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.core.ml.job.config; package org.elasticsearch.xpack.core.ml.job.config;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
@ -30,6 +31,7 @@ public class MlFilter implements ToXContentObject, Writeable {
public static final ParseField TYPE = new ParseField("type"); public static final ParseField TYPE = new ParseField("type");
public static final ParseField ID = new ParseField("filter_id"); public static final ParseField ID = new ParseField("filter_id");
public static final ParseField DESCRIPTION = new ParseField("description");
public static final ParseField ITEMS = new ParseField("items"); public static final ParseField ITEMS = new ParseField("items");
// For QueryPage // For QueryPage
@ -43,27 +45,38 @@ public class MlFilter implements ToXContentObject, Writeable {
parser.declareString((builder, s) -> {}, TYPE); parser.declareString((builder, s) -> {}, TYPE);
parser.declareString(Builder::setId, ID); parser.declareString(Builder::setId, ID);
parser.declareStringOrNull(Builder::setDescription, DESCRIPTION);
parser.declareStringArray(Builder::setItems, ITEMS); parser.declareStringArray(Builder::setItems, ITEMS);
return parser; return parser;
} }
private final String id; private final String id;
private final String description;
private final List<String> items; private final List<String> items;
public MlFilter(String id, List<String> items) { public MlFilter(String id, String description, List<String> items) {
this.id = Objects.requireNonNull(id, ID.getPreferredName() + " must not be null"); this.id = Objects.requireNonNull(id, ID.getPreferredName() + " must not be null");
this.description = description;
this.items = Objects.requireNonNull(items, ITEMS.getPreferredName() + " must not be null"); this.items = Objects.requireNonNull(items, ITEMS.getPreferredName() + " must not be null");
} }
public MlFilter(StreamInput in) throws IOException { public MlFilter(StreamInput in) throws IOException {
id = in.readString(); id = in.readString();
if (in.getVersion().onOrAfter(Version.V_6_4_0)) {
description = in.readOptionalString();
} else {
description = null;
}
items = Arrays.asList(in.readStringArray()); items = Arrays.asList(in.readStringArray());
} }
@Override @Override
public void writeTo(StreamOutput out) throws IOException { public void writeTo(StreamOutput out) throws IOException {
out.writeString(id); out.writeString(id);
if (out.getVersion().onOrAfter(Version.V_6_4_0)) {
out.writeOptionalString(description);
}
out.writeStringArray(items.toArray(new String[items.size()])); out.writeStringArray(items.toArray(new String[items.size()]));
} }
@ -71,6 +84,9 @@ public class MlFilter implements ToXContentObject, Writeable {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(); builder.startObject();
builder.field(ID.getPreferredName(), id); builder.field(ID.getPreferredName(), id);
if (description != null) {
builder.field(DESCRIPTION.getPreferredName(), description);
}
builder.field(ITEMS.getPreferredName(), items); builder.field(ITEMS.getPreferredName(), items);
if (params.paramAsBoolean(MlMetaIndex.INCLUDE_TYPE_KEY, false)) { if (params.paramAsBoolean(MlMetaIndex.INCLUDE_TYPE_KEY, false)) {
builder.field(TYPE.getPreferredName(), FILTER_TYPE); builder.field(TYPE.getPreferredName(), FILTER_TYPE);
@ -83,6 +99,10 @@ public class MlFilter implements ToXContentObject, Writeable {
return id; return id;
} }
public String getDescription() {
return description;
}
public List<String> getItems() { public List<String> getItems() {
return new ArrayList<>(items); return new ArrayList<>(items);
} }
@ -98,12 +118,12 @@ public class MlFilter implements ToXContentObject, Writeable {
} }
MlFilter other = (MlFilter) obj; MlFilter other = (MlFilter) obj;
return id.equals(other.id) && items.equals(other.items); return id.equals(other.id) && Objects.equals(description, other.description) && items.equals(other.items);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(id, items); return Objects.hash(id, description, items);
} }
public String documentId() { public String documentId() {
@ -114,30 +134,45 @@ public class MlFilter implements ToXContentObject, Writeable {
return DOCUMENT_ID_PREFIX + filterId; return DOCUMENT_ID_PREFIX + filterId;
} }
public static Builder builder(String filterId) {
return new Builder().setId(filterId);
}
public static class Builder { public static class Builder {
private String id; private String id;
private String description;
private List<String> items = Collections.emptyList(); private List<String> items = Collections.emptyList();
private Builder() {}
public Builder setId(String id) { public Builder setId(String id) {
this.id = id; this.id = id;
return this; return this;
} }
private Builder() {}
@Nullable @Nullable
public String getId() { public String getId() {
return id; return id;
} }
public Builder setDescription(String description) {
this.description = description;
return this;
}
public Builder setItems(List<String> items) { public Builder setItems(List<String> items) {
this.items = items; this.items = items;
return this; return this;
} }
public Builder setItems(String... items) {
this.items = Arrays.asList(items);
return this;
}
public MlFilter build() { public MlFilter build() {
return new MlFilter(id, items); return new MlFilter(id, description, items);
} }
} }
} }

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.core.security.rest; package org.elasticsearch.xpack.core.security.rest;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
@ -17,7 +16,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -33,37 +31,15 @@ public interface RestRequestFilter {
default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException { default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException {
Set<String> fields = getFilteredFields(); Set<String> fields = getFilteredFields();
if (restRequest.hasContent() && fields.isEmpty() == false) { if (restRequest.hasContent() && fields.isEmpty() == false) {
return new RestRequest(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders()) { return new RestRequest(restRequest) {
private BytesReference filteredBytes = null; private BytesReference filteredBytes = null;
@Override
public Method method() {
return restRequest.method();
}
@Override
public String uri() {
return restRequest.uri();
}
@Override @Override
public boolean hasContent() { public boolean hasContent() {
return true; return true;
} }
@Nullable
@Override
public SocketAddress getRemoteAddress() {
return restRequest.getRemoteAddress();
}
@Nullable
@Override
public SocketAddress getLocalAddress() {
return restRequest.getLocalAddress();
}
@Override @Override
public BytesReference content() { public BytesReference content() {
if (filteredBytes == null) { if (filteredBytes == null) {

View File

@ -9,6 +9,7 @@ import org.elasticsearch.test.AbstractStreamableTestCase;
import org.elasticsearch.xpack.core.ml.action.GetFiltersAction.Response; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction.Response;
import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.job.config.MlFilter; import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests;
import java.util.Collections; import java.util.Collections;
@ -17,9 +18,7 @@ public class GetFiltersActionResponseTests extends AbstractStreamableTestCase<Ge
@Override @Override
protected Response createTestInstance() { protected Response createTestInstance() {
final QueryPage<MlFilter> result; final QueryPage<MlFilter> result;
MlFilter doc = MlFilterTests.createRandom();
MlFilter doc = new MlFilter(
randomAlphaOfLengthBetween(1, 20), Collections.singletonList(randomAlphaOfLengthBetween(1, 20)));
result = new QueryPage<>(Collections.singletonList(doc), 1, MlFilter.RESULTS_FIELD); result = new QueryPage<>(Collections.singletonList(doc), 1, MlFilter.RESULTS_FIELD);
return new Response(result); return new Response(result);
} }

View File

@ -8,10 +8,7 @@ package org.elasticsearch.xpack.core.ml.action;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.test.AbstractStreamableXContentTestCase; import org.elasticsearch.test.AbstractStreamableXContentTestCase;
import org.elasticsearch.xpack.core.ml.action.PutFilterAction.Request; import org.elasticsearch.xpack.core.ml.action.PutFilterAction.Request;
import org.elasticsearch.xpack.core.ml.job.config.MlFilter; import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests;
import java.util.ArrayList;
import java.util.List;
public class PutFilterActionRequestTests extends AbstractStreamableXContentTestCase<Request> { public class PutFilterActionRequestTests extends AbstractStreamableXContentTestCase<Request> {
@ -19,13 +16,7 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC
@Override @Override
protected Request createTestInstance() { protected Request createTestInstance() {
int size = randomInt(10); return new PutFilterAction.Request(MlFilterTests.createRandom(filterId));
List<String> items = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
items.add(randomAlphaOfLengthBetween(1, 20));
}
MlFilter filter = new MlFilter(filterId, items);
return new PutFilterAction.Request(filter);
} }
@Override @Override
@ -42,5 +33,4 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC
protected Request doParseInstance(XContentParser parser) { protected Request doParseInstance(XContentParser parser) {
return PutFilterAction.Request.parseRequest(filterId, parser); return PutFilterAction.Request.parseRequest(filterId, parser);
} }
} }

View File

@ -26,12 +26,25 @@ public class MlFilterTests extends AbstractSerializingTestCase<MlFilter> {
@Override @Override
protected MlFilter createTestInstance() { protected MlFilter createTestInstance() {
return createRandom();
}
public static MlFilter createRandom() {
return createRandom(randomAlphaOfLengthBetween(1, 20));
}
public static MlFilter createRandom(String filterId) {
String description = null;
if (randomBoolean()) {
description = randomAlphaOfLength(20);
}
int size = randomInt(10); int size = randomInt(10);
List<String> items = new ArrayList<>(size); List<String> items = new ArrayList<>(size);
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
items.add(randomAlphaOfLengthBetween(1, 20)); items.add(randomAlphaOfLengthBetween(1, 20));
} }
return new MlFilter(randomAlphaOfLengthBetween(1, 20), items); return new MlFilter(filterId, description, items);
} }
@Override @Override
@ -45,13 +58,13 @@ public class MlFilterTests extends AbstractSerializingTestCase<MlFilter> {
} }
public void testNullId() { public void testNullId() {
NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, Collections.emptyList())); NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, "", Collections.emptyList()));
assertEquals(MlFilter.ID.getPreferredName() + " must not be null", ex.getMessage()); assertEquals(MlFilter.ID.getPreferredName() + " must not be null", ex.getMessage());
} }
public void testNullItems() { public void testNullItems() {
NullPointerException ex = NullPointerException ex =
expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), null)); expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), "", null));
assertEquals(MlFilter.ITEMS.getPreferredName() + " must not be null", ex.getMessage()); assertEquals(MlFilter.ITEMS.getPreferredName() + " must not be null", ex.getMessage());
} }

View File

@ -385,8 +385,8 @@ public class JobProviderIT extends MlSingleNodeTestCase {
indexScheduledEvents(events); indexScheduledEvents(events);
List<MlFilter> filters = new ArrayList<>(); List<MlFilter> filters = new ArrayList<>();
filters.add(new MlFilter("fruit", Arrays.asList("apple", "pear"))); filters.add(MlFilter.builder("fruit").setItems("apple", "pear").build());
filters.add(new MlFilter("tea", Arrays.asList("green", "builders"))); filters.add(MlFilter.builder("tea").setItems("green", "builders").build());
indexFilters(filters); indexFilters(filters);
DataCounts earliestCounts = DataCountsTests.createTestInstance(jobId); DataCounts earliestCounts = DataCountsTests.createTestInstance(jobId);

View File

@ -210,7 +210,7 @@ public class JobManagerTests extends ESTestCase {
JobManager jobManager = createJobManager(); JobManager jobManager = createJobManager();
MlFilter filter = new MlFilter("foo_filter", Arrays.asList("a", "b")); MlFilter filter = MlFilter.builder("foo_filter").setItems("a", "b").build();
jobManager.updateProcessOnFilterChanged(filter); jobManager.updateProcessOnFilterChanged(filter);

View File

@ -207,8 +207,8 @@ public class ControlMsgToProcessWriterTests extends ESTestCase {
public void testWriteUpdateFiltersMessage() throws IOException { public void testWriteUpdateFiltersMessage() throws IOException {
ControlMsgToProcessWriter writer = new ControlMsgToProcessWriter(lengthEncodedWriter, 2); ControlMsgToProcessWriter writer = new ControlMsgToProcessWriter(lengthEncodedWriter, 2);
MlFilter filter1 = new MlFilter("filter_1", Arrays.asList("a")); MlFilter filter1 = MlFilter.builder("filter_1").setItems("a").build();
MlFilter filter2 = new MlFilter("filter_2", Arrays.asList("b", "c")); MlFilter filter2 = MlFilter.builder("filter_2").setItems("b", "c").build();
writer.writeUpdateFiltersMessage(Arrays.asList(filter1, filter2)); writer.writeUpdateFiltersMessage(Arrays.asList(filter1, filter2));

View File

@ -220,8 +220,8 @@ public class FieldConfigWriterTests extends ESTestCase {
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(d)); AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(d));
analysisConfig = builder.build(); analysisConfig = builder.build();
filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); filters.add(MlFilter.builder("filter_1").setItems("a", "b").build());
filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); filters.add(MlFilter.builder("filter_2").setItems("c", "d").build());
writer = mock(OutputStreamWriter.class); writer = mock(OutputStreamWriter.class);
createFieldConfigWriter().write(); createFieldConfigWriter().write();

View File

@ -10,7 +10,6 @@ import org.elasticsearch.xpack.core.ml.job.config.MlFilter;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -28,8 +27,8 @@ public class MlFilterWriterTests extends ESTestCase {
public void testWrite() throws IOException { public void testWrite() throws IOException {
List<MlFilter> filters = new ArrayList<>(); List<MlFilter> filters = new ArrayList<>();
filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); filters.add(MlFilter.builder("filter_1").setItems("a", "b").build());
filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); filters.add(MlFilter.builder("filter_2").setItems("c", "d").build());
StringBuilder buffer = new StringBuilder(); StringBuilder buffer = new StringBuilder();
new MlFilterWriter(filters, buffer).write(); new MlFilterWriter(filters, buffer).write();

View File

@ -69,7 +69,6 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.UnknownHostException; import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.ArrayList; import java.util.ArrayList;
@ -829,10 +828,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); msg.builder.field(Field.REQUEST_BODY, restRequestContent(request));
} }
msg.builder.field(Field.ORIGIN_TYPE, "rest"); msg.builder.field(Field.ORIGIN_TYPE, "rest");
SocketAddress address = request.getRemoteAddress(); InetSocketAddress address = request.getHttpChannel().getRemoteAddress();
if (address instanceof InetSocketAddress) { if (address != null) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress()));
.getAddress()));
} else { } else {
msg.builder.field(Field.ORIGIN_ADDRESS, address); msg.builder.field(Field.ORIGIN_ADDRESS, address);
} }
@ -854,10 +852,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); msg.builder.field(Field.REQUEST_BODY, restRequestContent(request));
} }
msg.builder.field(Field.ORIGIN_TYPE, "rest"); msg.builder.field(Field.ORIGIN_TYPE, "rest");
SocketAddress address = request.getRemoteAddress(); InetSocketAddress address = request.getHttpChannel().getRemoteAddress();
if (address instanceof InetSocketAddress) { if (address != null) {
msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress()));
.getAddress()));
} else { } else {
msg.builder.field(Field.ORIGIN_ADDRESS, address); msg.builder.field(Field.ORIGIN_ADDRESS, address);
} }

View File

@ -38,7 +38,6 @@ import org.elasticsearch.xpack.security.transport.filter.SecurityIpFilterRule;
import java.net.InetAddress; import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet; import java.util.EnumSet;
@ -544,13 +543,8 @@ public class LoggingAuditTrail extends AbstractComponent implements AuditTrail,
} }
private static String hostAttributes(RestRequest request) { private static String hostAttributes(RestRequest request) {
String formattedAddress; final InetSocketAddress socketAddress = request.getHttpChannel().getRemoteAddress();
final SocketAddress socketAddress = request.getRemoteAddress(); String formattedAddress = NetworkAddress.format(socketAddress.getAddress());
if (socketAddress instanceof InetSocketAddress) {
formattedAddress = NetworkAddress.format(((InetSocketAddress) socketAddress).getAddress());
} else {
formattedAddress = socketAddress.toString();
}
return "origin_address=[" + formattedAddress + "]"; return "origin_address=[" + formattedAddress + "]";
} }

View File

@ -20,7 +20,7 @@ public class RemoteHostHeader {
* then be copied to the subsequent action requests. * then be copied to the subsequent action requests.
*/ */
public static void process(RestRequest request, ThreadContext threadContext) { public static void process(RestRequest request, ThreadContext threadContext) {
threadContext.putTransient(KEY, request.getRemoteAddress()); threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress());
} }
/** /**

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.xpack.security.rest; package org.elasticsearch.xpack.security.rest;
import io.netty.channel.Channel;
import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandler;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.message.ParameterizedMessage;
@ -13,7 +14,8 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.logging.ESLoggerFactory; import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.netty4.Netty4HttpRequest; import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.netty4.Netty4HttpChannel;
import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestChannel;
@ -50,10 +52,11 @@ public class SecurityRestFilter implements RestHandler {
if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) { if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) {
// CORS - allow for preflight unauthenticated OPTIONS request // CORS - allow for preflight unauthenticated OPTIONS request
if (extractClientCertificate) { if (extractClientCertificate) {
Netty4HttpRequest nettyHttpRequest = (Netty4HttpRequest) request; HttpChannel httpChannel = request.getHttpChannel();
SslHandler handler = nettyHttpRequest.getChannel().pipeline().get(SslHandler.class); Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel();
SslHandler handler = nettyChannel.pipeline().get(SslHandler.class);
assert handler != null; assert handler != null;
ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel()); ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyChannel);
} }
service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap( service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap(
authentication -> { authentication -> {

View File

@ -104,7 +104,7 @@ public class SecurityNetty4HttpServerTransport extends Netty4HttpServerTransport
private final class HttpSslChannelHandler extends HttpChannelHandler { private final class HttpSslChannelHandler extends HttpChannelHandler {
HttpSslChannelHandler() { HttpSslChannelHandler() {
super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext()); super(SecurityNetty4HttpServerTransport.this, handlingSettings);
} }
@Override @Override

View File

@ -33,6 +33,7 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.plugins.MetaDataUpgrader; import org.elasticsearch.plugins.MetaDataUpgrader;
import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
@ -914,7 +915,9 @@ public class IndexAuditTrailTests extends SecurityIntegTestCase {
private RestRequest mockRestRequest() { private RestRequest mockRestRequest() {
RestRequest request = mock(RestRequest.class); RestRequest request = mock(RestRequest.class);
when(request.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); HttpChannel httpChannel = mock(HttpChannel.class);
when(request.getHttpChannel()).thenReturn(httpChannel);
when(httpChannel.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200));
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
return request; return request;
} }

View File

@ -88,6 +88,6 @@ public class RestRequestFilterTests extends ESTestCase {
new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON) new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON)
.withRemoteAddress(address).build(); .withRemoteAddress(address).build();
RestRequest filtered = filter.getFilteredRequest(restRequest); RestRequest filtered = filter.getFilteredRequest(restRequest);
assertEquals(address, filtered.getRemoteAddress()); assertEquals(address, filtered.getHttpChannel().getRemoteAddress());
} }
} }

View File

@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestChannel;
@ -67,6 +68,7 @@ public class SecurityRestFilterTests extends ESTestCase {
public void testProcess() throws Exception { public void testProcess() throws Exception {
RestRequest request = mock(RestRequest.class); RestRequest request = mock(RestRequest.class);
when(request.getHttpChannel()).thenReturn(mock(HttpChannel.class));
Authentication authentication = mock(Authentication.class); Authentication authentication = mock(Authentication.class);
doAnswer((i) -> { doAnswer((i) -> {
ActionListener callback = ActionListener callback =

View File

@ -20,7 +20,10 @@ integTest.enabled = false
dependencies { dependencies {
compileOnly "org.elasticsearch.plugin:x-pack-core:${version}" compileOnly "org.elasticsearch.plugin:x-pack-core:${version}"
compileOnly project(':modules:lang-painless') compileOnly(project(':modules:lang-painless')) {
// exclude ASM to not affect featureAware task on Java 10+
exclude group: "org.ow2.asm"
}
compile project('sql-proto') compile project('sql-proto')
compile "org.elasticsearch.plugin:aggs-matrix-stats-client:${version}" compile "org.elasticsearch.plugin:aggs-matrix-stats-client:${version}"
compile "org.antlr:antlr4-runtime:4.5.3" compile "org.antlr:antlr4-runtime:4.5.3"

View File

@ -32,6 +32,7 @@ setup:
filter_id: filter-foo2 filter_id: filter-foo2
body: > body: >
{ {
"description": "This filter has a description",
"items": ["123", "lmnop"] "items": ["123", "lmnop"]
} }
@ -76,6 +77,7 @@ setup:
- match: - match:
filters.1: filters.1:
filter_id: "filter-foo2" filter_id: "filter-foo2"
description: "This filter has a description"
items: ["123", "lmnop"] items: ["123", "lmnop"]
- do: - do:

View File

@ -120,7 +120,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
} }
public void testScope() throws Exception { public void testScope() throws Exception {
MlFilter safeIps = new MlFilter("safe_ips", Arrays.asList("111.111.111.111", "222.222.222.222")); MlFilter safeIps = MlFilter.builder("safe_ips").setItems("111.111.111.111", "222.222.222.222").build();
assertThat(putMlFilter(safeIps), is(true)); assertThat(putMlFilter(safeIps), is(true));
DetectionRule rule = new DetectionRule.Builder(RuleScope.builder().include("ip", "safe_ips")).build(); DetectionRule rule = new DetectionRule.Builder(RuleScope.builder().include("ip", "safe_ips")).build();
@ -178,7 +178,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
assertThat(records.get(0).getOverFieldValue(), equalTo("333.333.333.333")); assertThat(records.get(0).getOverFieldValue(), equalTo("333.333.333.333"));
// Now let's update the filter // Now let's update the filter
MlFilter updatedFilter = new MlFilter(safeIps.getId(), Collections.singletonList("333.333.333.333")); MlFilter updatedFilter = MlFilter.builder(safeIps.getId()).setItems("333.333.333.333").build();
assertThat(putMlFilter(updatedFilter), is(true)); assertThat(putMlFilter(updatedFilter), is(true));
// Wait until the notification that the process was updated is indexed // Wait until the notification that the process was updated is indexed
@ -229,7 +229,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase {
public void testScopeAndCondition() throws IOException { public void testScopeAndCondition() throws IOException {
// We have 2 IPs and they're both safe-listed. // We have 2 IPs and they're both safe-listed.
List<String> ips = Arrays.asList("111.111.111.111", "222.222.222.222"); List<String> ips = Arrays.asList("111.111.111.111", "222.222.222.222");
MlFilter safeIps = new MlFilter("safe_ips", ips); MlFilter safeIps = MlFilter.builder("safe_ips").setItems(ips).build();
assertThat(putMlFilter(safeIps), is(true)); assertThat(putMlFilter(safeIps), is(true));
// Ignore if ip in safe list AND actual < 10. // Ignore if ip in safe list AND actual < 10.