diff --git a/build.gradle b/build.gradle index ec81047e3e6..9bb08cf29db 100644 --- a/build.gradle +++ b/build.gradle @@ -53,9 +53,23 @@ subprojects { description = "Elasticsearch subproject ${project.path}" } +apply plugin: 'nebula.info-scm' +String licenseCommit +if (VersionProperties.elasticsearch.toString().endsWith('-SNAPSHOT')) { + licenseCommit = scminfo.change ?: "master" // leniency for non git builds +} else { + licenseCommit = "v${version}" +} +String elasticLicenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt" + subprojects { + // Default to the apache license project.ext.licenseName = 'The Apache Software License, Version 2.0' project.ext.licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt' + + // But stick the Elastic license url in project.ext so we can get it if we need to switch to it + project.ext.elasticLicenseUrl = elasticLicenseUrl + // we only use maven publish to add tasks for pom generation plugins.withType(MavenPublishPlugin).whenPluginAdded { publishing { diff --git a/distribution/archives/build.gradle b/distribution/archives/build.gradle index c1097b68b89..71606c2c027 100644 --- a/distribution/archives/build.gradle +++ b/distribution/archives/build.gradle @@ -228,6 +228,8 @@ subprojects { check.dependsOn checkNotice if (project.name == 'zip' || project.name == 'tar') { + project.ext.licenseName = 'Elastic License' + project.ext.licenseUrl = ext.elasticLicenseUrl task checkMlCppNotice { dependsOn buildDist, checkExtraction onlyIf toolExists diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java index cb31d444544..473985d2109 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java @@ -19,252 +19,58 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.net.InetSocketAddress; -final class Netty4HttpChannel extends AbstractRestChannel { +public class Netty4HttpChannel implements HttpChannel { - private final Netty4HttpServerTransport transport; private final Channel channel; - private final FullHttpRequest nettyRequest; - private final int sequence; - private final ThreadContext threadContext; - private final HttpHandlingSettings handlingSettings; - /** - * @param transport The corresponding NettyHttpServerTransport where this channel belongs to. - * @param request The request that is handled by this channel. - * @param sequence The pipelining sequence number for this request - * @param handlingSettings true if error messages should include stack traces. - * @param threadContext the thread context for the channel - */ - Netty4HttpChannel(Netty4HttpServerTransport transport, Netty4HttpRequest request, int sequence, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { - super(request, handlingSettings.getDetailedErrorsEnabled()); - this.transport = transport; - this.channel = request.getChannel(); - this.nettyRequest = request.request(); - this.sequence = sequence; - this.threadContext = threadContext; - this.handlingSettings = handlingSettings; + Netty4HttpChannel(Channel channel) { + this.channel = channel; } @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(transport.bigArrays); + public void sendResponse(HttpResponse response, ActionListener listener) { + ChannelPromise writePromise = channel.newPromise(); + writePromise.addListener(f -> { + if (f.isSuccess()) { + listener.onResponse(null); + } else { + final Throwable cause = f.cause(); + Netty4Utils.maybeDie(cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); + } + } + }); + channel.writeAndFlush(response, writePromise); } @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = Netty4Utils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - Netty4CorsHandler.setCorsResponseHeaders(nettyRequest, resp, transport.getCorsConfig()); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - BytesReference content = response.content(); - boolean releaseContent = content instanceof Releasable; - boolean releaseBytesStreamOutput = bytesOutputOrNull() instanceof ReleasableBytesStreamOutput; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - final ChannelPromise promise = channel.newPromise(); - - if (releaseContent) { - promise.addListener(f -> ((Releasable) content).close()); - } - - if (releaseBytesStreamOutput) { - promise.addListener(f -> bytesOutputOrNull().close()); - } - - if (isCloseConnection()) { - promise.addListener(ChannelFutureListener.CLOSE); - } - - Netty4HttpResponse newResponse = new Netty4HttpResponse(sequence, resp); - - channel.writeAndFlush(newResponse, promise); - releaseContent = false; - releaseBytesStreamOutput = false; - } finally { - if (releaseContent) { - ((Releasable) content).close(); - } - if (releaseBytesStreamOutput) { - bytesOutputOrNull().close(); - } - } + public InetSocketAddress getLocalAddress() { + return (InetSocketAddress) channel.localAddress(); } - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); + @Override + public InetSocketAddress getRemoteAddress() { + return (InetSocketAddress) channel.remoteAddress(); } - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } + @Override + public void close() { + channel.close(); } - private void addCookies(HttpResponse resp) { - if (handlingSettings.isResetCookies()) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); - } - } - } - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); - } - - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; - } - - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); - } - - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public Channel getNettyChannel() { + return channel; } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 12c2e9a6857..e6436ccea1a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -66,7 +66,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { try { List> readyResponses = aggregator.write(response, promise); for (Tuple readyResponse : readyResponses) { - ctx.write(readyResponse.v1().getResponse(), readyResponse.v2()); + ctx.write(readyResponse.v1(), readyResponse.v2()); } success = true; } catch (IllegalStateException e) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 2ce6ffada67..ffabe5cbbe2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -19,17 +19,22 @@ package org.elasticsearch.http.netty4; -import io.netty.channel.Channel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.net.SocketAddress; import java.util.AbstractMap; import java.util.Collection; import java.util.Collections; @@ -38,25 +43,16 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -public class Netty4HttpRequest extends RestRequest { - +public class Netty4HttpRequest implements HttpRequest { private final FullHttpRequest request; - private final Channel channel; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - /** - * Construct a new request. - * - * @param xContentRegistry the content registry - * @param request the underlying request - * @param channel the channel for the request - * @throws BadParameterException if the parameters can not be decoded - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); + Netty4HttpRequest(FullHttpRequest request, int sequence) { this.request = request; - this.channel = channel; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = Netty4Utils.toBytesReference(request.content()); } else { @@ -64,71 +60,39 @@ public class Netty4HttpRequest extends RestRequest { } } - /** - * Construct a new request. In contrast to - * {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so - * this constructor will not throw a {@link BadParameterException}. - * - * @param xContentRegistry the content registry - * @param params the parameters for the request - * @param uri the path for the request - * @param request the underlying request - * @param channel the channel for the request - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String uri, - final FullHttpRequest request, - final Channel channel) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); - this.request = request; - this.channel = channel; - if (request.content().isReadable()) { - this.content = Netty4Utils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - } - - public FullHttpRequest request() { - return this.request; - } - @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -139,40 +103,64 @@ public class Netty4HttpRequest extends RestRequest { return request.uri(); } - @Override - public boolean hasContent() { - return content.length() > 0; - } - @Override public BytesReference content() { return content; } - /** - * Returns the remote address where this rest request channel is "connected to". The - * returned {@link SocketAddress} is supposed to be down-cast into more - * concrete type such as {@link java.net.InetSocketAddress} to retrieve - * the detailed information. - */ + @Override - public SocketAddress getRemoteAddress() { - return channel.remoteAddress(); + public final Map> getHeaders() { + return headers; } - /** - * Returns the local address where this request channel is bound to. The returned - * {@link SocketAddress} is supposed to be down-cast into more concrete - * type such as {@link java.net.InetSocketAddress} to retrieve the detailed - * information. - */ @Override - public SocketAddress getLocalAddress() { - return channel.localAddress(); + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); } - public Channel getChannel() { - return channel; + @Override + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } + } + + @Override + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new Netty4HttpRequest(requestWithoutHeader, sequence); + } + + @Override + public Netty4HttpResponse createResponse(RestStatus status, BytesReference content) { + return new Netty4HttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { + return request; + } + + int sequence() { + return sequence; } /** @@ -249,7 +237,7 @@ public class Netty4HttpRequest extends RestRequest { @Override public Set>> entrySet() { return httpHeaders.names().stream().map(k -> new AbstractMap.SimpleImmutableEntry<>(k, httpHeaders.getAll(k))) - .collect(Collectors.toSet()); + .collect(Collectors.toSet()); } } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java index c3a010226a4..4547a63a9a2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java @@ -20,112 +20,51 @@ package org.elasticsearch.http.netty4; import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaders; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpPipelinedRequest; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; - @ChannelHandler.Sharable class Netty4HttpRequestHandler extends SimpleChannelInboundHandler> { private final Netty4HttpServerTransport serverTransport; - private final HttpHandlingSettings handlingSettings; - private final ThreadContext threadContext; - Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { + Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport) { this.serverTransport = serverTransport; - this.handlingSettings = handlingSettings; - this.threadContext = threadContext; } @Override protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest msg) throws Exception { - final FullHttpRequest request = msg.getRequest(); + Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); + FullHttpRequest request = msg.getRequest(); try { + final FullHttpRequest copiedRequest = + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + Unpooled.copiedBuffer(request.content()), + request.headers(), + request.trailingHeaders()); - final FullHttpRequest copy = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - Unpooled.copiedBuffer(request.content()), - request.headers(), - request.trailingHeaders()); - - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final Netty4HttpRequest httpRequest; - { - Netty4HttpRequest innerHttpRequest; - try { - innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel()); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copy, ctx.channel()); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these - * parameter values. - */ - final Netty4HttpChannel channel; - { - Netty4HttpChannel innerChannel; - try { - innerChannel = - new Netty4HttpChannel(serverTransport, httpRequest, msg.getSequence(), handlingSettings, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final Netty4HttpRequest innerRequest = - new Netty4HttpRequest( - serverTransport.xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copy.uri(), - copy, - ctx.channel()); - innerChannel = - new Netty4HttpChannel(serverTransport, innerRequest, msg.getSequence(), handlingSettings, threadContext); - } - channel = innerChannel; - } + Netty4HttpRequest httpRequest = new Netty4HttpRequest(copiedRequest, msg.getSequence()); if (request.decoderResult().isFailure()) { - serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + serverTransport.incomingRequestError(httpRequest, channel, new Exception(cause)); + } else { + serverTransport.incomingRequestError(httpRequest, channel, (Exception) cause); + } } else { - serverTransport.dispatchRequest(httpRequest, channel); + serverTransport.incomingRequest(httpRequest, channel); } } finally { // As we have copied the buffer, we can release the request @@ -133,32 +72,6 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler MAP; + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + } + } + diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 0e18232e01c..6bfd8168dbe 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -39,6 +39,7 @@ import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.util.AttributeKey; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.common.Strings; @@ -53,9 +54,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.HttpHandlingSettings; @@ -149,38 +148,29 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public static final Setting SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE = Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); - protected final BigArrays bigArrays; + private final ByteSizeValue maxInitialLineLength; + private final ByteSizeValue maxHeaderSize; + private final ByteSizeValue maxChunkSize; - protected final ByteSizeValue maxInitialLineLength; - protected final ByteSizeValue maxHeaderSize; - protected final ByteSizeValue maxChunkSize; + private final int workerCount; - protected final int workerCount; + private final int pipeliningMaxEvents; - protected final int pipeliningMaxEvents; + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean reuseAddress; - /** - * The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}. - */ - protected final NamedXContentRegistry xContentRegistry; - - protected final boolean tcpNoDelay; - protected final boolean tcpKeepAlive; - protected final boolean reuseAddress; - - protected final ByteSizeValue tcpSendBufferSize; - protected final ByteSizeValue tcpReceiveBufferSize; - protected final RecvByteBufAllocator recvByteBufAllocator; + private final ByteSizeValue tcpSendBufferSize; + private final ByteSizeValue tcpReceiveBufferSize; + private final RecvByteBufAllocator recvByteBufAllocator; private final int readTimeoutMillis; - protected final int maxCompositeBufferComponents; + private final int maxCompositeBufferComponents; protected volatile ServerBootstrap serverBootstrap; protected final List serverChannels = new ArrayList<>(); - protected final HttpHandlingSettings httpHandlingSettings; - // package private for testing Netty4OpenChannelsHandler serverOpenChannels; @@ -189,16 +179,13 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); - this.bigArrays = bigArrays; - this.xContentRegistry = xContentRegistry; this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings); this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings); @@ -398,26 +385,27 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { } public ChannelHandler configureServerChannelHandler() { - return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext()); + return new HttpChannelHandler(this, handlingSettings); } + static final AttributeKey HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel"); + protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; private final Netty4HttpRequestHandler requestHandler; private final HttpHandlingSettings handlingSettings; - protected HttpChannelHandler( - final Netty4HttpServerTransport transport, - final HttpHandlingSettings handlingSettings, - final ThreadContext threadContext) { + protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) { this.transport = transport; this.handlingSettings = handlingSettings; - this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext); + this.requestHandler = new Netty4HttpRequestHandler(transport); } @Override protected void initChannel(Channel ch) throws Exception { + Netty4HttpChannel nettyTcpChannel = new Netty4HttpChannel(ch); + ch.attr(HTTP_CHANNEL_KEY).set(nettyTcpChannel); ch.pipeline().addLast("openChannels", transport.serverOpenChannels); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); final HttpRequestDecoder decoder = new HttpRequestDecoder( diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java index 779eb4fe2e4..38d832d6080 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java @@ -22,6 +22,7 @@ package org.elasticsearch.http.netty4.cors; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.netty4.Netty4HttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public class Netty4CorsHandler extends ChannelDuplexHandler { ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass(); + Netty4HttpResponse response = (Netty4HttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise);; + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index f4818a2e567..466c4b68bfa 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -333,10 +333,10 @@ public class Netty4Transport extends TcpTransport { addClosedExceptionLogger(ch); NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name); ch.attr(CHANNEL_KEY).set(nettyTcpChannel); - serverAcceptedChannel(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name)); + serverAcceptedChannel(nettyTcpChannel); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java index f650e757e7a..89fabdcd763 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java @@ -98,8 +98,11 @@ public class NettyTcpChannel implements TcpChannel { } else { final Throwable cause = f.cause(); Netty4Utils.maybeDie(cause); - assert cause instanceof Exception; - listener.onFailure((Exception) cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); + } } }); channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java new file mode 100644 index 00000000000..15a0850f64d --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java @@ -0,0 +1,148 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class Netty4CorsTests extends ESTestCase { + + public void testCorsEnabledWithoutAllowOrigins() { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() { + final String originValue = Netty4CorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { + // construct request and send it over the transport layer + final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + embeddedChannel.pipeline().addLast(new Netty4CorsHandler(Netty4HttpServerTransport.buildCorsConfig(settings))); + Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest, 0); + embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content"))); + return embeddedChannel.readOutbound(); + } +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java deleted file mode 100644 index 7c5b35a3229..00000000000 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java +++ /dev/null @@ -1,616 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.netty4; - -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelConfig; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelId; -import io.netty.channel.ChannelMetadata; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelProgressivePromise; -import io.netty.channel.ChannelPromise; -import io.netty.channel.EventLoop; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.ReleasablePagedBytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ByteArray; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.NullDispatcher; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.netty4.Netty4Utils; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.SocketAddress; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; - -public class Netty4HttpChannelTests extends ESTestCase { - - private NetworkService networkService; - private ThreadPool threadPool; - private MockBigArrays bigArrays; - - @Before - public void setup() throws Exception { - networkService = new NetworkService(Collections.emptyList()); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(Netty4Utils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = Netty4CorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - // send a response - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - HttpResponse response = ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - } - - public void testReleaseOnSendToClosedChannel() { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse response = new TestResponse(bigArrays); - assertThat(response.content(), instanceOf(Releasable.class)); - embeddedChannel.close(); - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testReleaseOnSendToChannelAfterException() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, embeddedChannel); - - // send a response, the channel close status should match - assertTrue(embeddedChannel.isOpen()); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - assertThat(embeddedChannel.isOpen(), equalTo(!close)); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = - new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - return ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - } - } - - private static class WriteCapturingChannel implements Channel { - - private List writtenObjects = new ArrayList<>(); - - @Override - public ChannelId id() { - return null; - } - - @Override - public EventLoop eventLoop() { - return null; - } - - @Override - public Channel parent() { - return null; - } - - @Override - public ChannelConfig config() { - return null; - } - - @Override - public boolean isOpen() { - return false; - } - - @Override - public boolean isRegistered() { - return false; - } - - @Override - public boolean isActive() { - return false; - } - - @Override - public ChannelMetadata metadata() { - return null; - } - - @Override - public SocketAddress localAddress() { - return null; - } - - @Override - public SocketAddress remoteAddress() { - return null; - } - - @Override - public ChannelFuture closeFuture() { - return null; - } - - @Override - public boolean isWritable() { - return false; - } - - @Override - public long bytesBeforeUnwritable() { - return 0; - } - - @Override - public long bytesBeforeWritable() { - return 0; - } - - @Override - public Unsafe unsafe() { - return null; - } - - @Override - public ChannelPipeline pipeline() { - return null; - } - - @Override - public ByteBufAllocator alloc() { - return null; - } - - @Override - public Channel read() { - return null; - } - - @Override - public Channel flush() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture disconnect() { - return null; - } - - @Override - public ChannelFuture close() { - return null; - } - - @Override - public ChannelFuture deregister() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture disconnect(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture close(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture deregister(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture write(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture write(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelPromise newPromise() { - return null; - } - - @Override - public ChannelProgressivePromise newProgressivePromise() { - return null; - } - - @Override - public ChannelFuture newSucceededFuture() { - return null; - } - - @Override - public ChannelFuture newFailedFuture(Throwable cause) { - return null; - } - - @Override - public ChannelPromise voidPromise() { - return null; - } - - @Override - public Attribute attr(AttributeKey key) { - return null; - } - - @Override - public boolean hasAttr(AttributeKey key) { - return false; - } - - @Override - public int compareTo(Channel o) { - return 0; - } - - List getWrittenObjects() { - return writtenObjects; - } - - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = Netty4Utils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - TestResponse(final BigArrays bigArrays) { - final byte[] bytes; - try { - bytes = "content".getBytes("UTF-8"); - } catch (final UnsupportedEncodingException e) { - throw new AssertionError(e); - } - final ByteArray bigArray = bigArrays.newByteArray(bytes.length); - bigArray.set(0, bytes, 0, bytes.length); - reference = new ReleasablePagedBytesReference(bigArrays, bigArray, bytes.length, Releasables.releaseOnce(bigArray)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } - -} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index f6c5dfd5a50..8b3ba19fe01 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -191,11 +190,11 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - Netty4HttpResponse resp = new Netty4HttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -233,10 +232,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -246,9 +245,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -260,7 +260,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java index f2b28b90918..3101f660d05 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -26,22 +26,20 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -120,7 +118,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { @Override public ChannelHandler configureServerChannelHandler() { - return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext()); + return new CustomHttpChannelHandler(this, executorService); } @Override @@ -135,8 +133,8 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { private final ExecutorService executorService; - CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) { - super(transport, transport.httpHandlingSettings, threadContext); + CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService) { + super(transport, transport.handlingSettings); this.executorService = executorService; } @@ -187,8 +185,9 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); - httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); + Netty4HttpRequest httpRequest = new Netty4HttpRequest(fullHttpRequest, pipelinedRequest.getSequence()); + Netty4HttpResponse response = httpRequest.createResponse(RestStatus.OK, new BytesArray(uri.getBytes(StandardCharsets.UTF_8))); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); final boolean slow = uri.matches("/slow/\\d+"); if (slow) { @@ -202,7 +201,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { } final ChannelPromise promise = ctx.newPromise(); - ctx.writeAndFlush(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.writeAndFlush(response, promise); } } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 5b22409b92d..bcf28506143 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -291,40 +291,6 @@ public class Netty4HttpServerTransportTests extends ESTestCase { assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (Netty4HttpServerTransport transport = - new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - public void testReadTimeout() throws Exception { final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index 05f28e8254a..ea75c62dbbc 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -23,54 +23,38 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; -import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; -import org.elasticsearch.rest.RestRequest; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.function.BiConsumer; - public class HttpReadWriteHandler implements ReadWriteHandler { private final NettyAdaptor adaptor; - private final NioSocketChannel nioChannel; + private final NioHttpChannel nioHttpChannel; private final NioHttpServerTransport transport; - private final HttpHandlingSettings settings; - private final NamedXContentRegistry xContentRegistry; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, - NamedXContentRegistry xContentRegistry, NioCorsConfig corsConfig, ThreadContext threadContext) { - this.nioChannel = nioChannel; + HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, + NioCorsConfig corsConfig) { + this.nioHttpChannel = nioHttpChannel; this.transport = transport; - this.settings = settings; - this.xContentRegistry = xContentRegistry; - this.corsConfig = corsConfig; - this.threadContext = threadContext; List handlers = new ArrayList<>(5); HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(), @@ -89,7 +73,7 @@ public class HttpReadWriteHandler implements ReadWriteHandler { handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents())); adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); - adaptor.addCloseListener((v, e) -> nioChannel.close()); + adaptor.addCloseListener((v, e) -> nioHttpChannel.close()); } @Override @@ -150,95 +134,22 @@ public class HttpReadWriteHandler implements ReadWriteHandler { request.headers(), request.trailingHeaders()); - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final NioHttpRequest httpRequest; - { - NioHttpRequest innerHttpRequest; - try { - innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copiedRequest); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of - * these parameter values. - */ - final NioHttpChannel channel; - { - NioHttpChannel innerChannel; - int sequence = pipelinedRequest.getSequence(); - BigArrays bigArrays = transport.getBigArrays(); - try { - innerChannel = new NioHttpChannel(nioChannel, bigArrays, httpRequest, sequence, settings, corsConfig, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final NioHttpRequest innerRequest = - new NioHttpRequest( - xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copiedRequest.uri(), - copiedRequest); - innerChannel = new NioHttpChannel(nioChannel, bigArrays, innerRequest, sequence, settings, corsConfig, threadContext); - } - channel = innerChannel; - } + NioHttpRequest httpRequest = new NioHttpRequest(copiedRequest, pipelinedRequest.getSequence()); if (request.decoderResult().isFailure()) { - transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - transport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + transport.incomingRequestError(httpRequest, nioHttpChannel, new Exception(cause)); + } else { + transport.incomingRequestError(httpRequest, nioHttpChannel, (Exception) cause); + } } else { - transport.dispatchRequest(httpRequest, channel); + transport.incomingRequest(httpRequest, nioHttpChannel); } } finally { // As we have copied the buffer, we can release the request request.release(); } } - - private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) { - final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); - headersWithoutContentTypeHeader.add(request.headers()); - headersWithoutContentTypeHeader.remove("Content-Type"); - final FullHttpRequest requestWithoutContentTypeHeader = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - request.content(), - headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again - request.trailingHeaders()); // Content-Type can not be a trailing header - try { - return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader); - } catch (final RestRequest.BadParameterException e) { - badRequestCause.addSuppressed(e); - return requestWithoutParameters(requestWithoutContentTypeHeader); - } - } - - private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) { - // remove all parameters as at least one is incorrectly encoded - return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request); - } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java index 634421b34ea..088f0e85dde 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java @@ -19,244 +19,21 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.Cookie; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import java.util.ArrayList; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.BiConsumer; +import java.io.IOException; +import java.nio.channels.SocketChannel; -public class NioHttpChannel extends AbstractRestChannel { +public class NioHttpChannel extends NioSocketChannel implements HttpChannel { - private final BigArrays bigArrays; - private final int sequence; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - private final FullHttpRequest nettyRequest; - private final NioSocketChannel nioChannel; - private final boolean resetCookies; - - NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, int sequence, - HttpHandlingSettings settings, NioCorsConfig corsConfig, ThreadContext threadContext) { - super(request, settings.getDetailedErrorsEnabled()); - this.nioChannel = nioChannel; - this.bigArrays = bigArrays; - this.sequence = sequence; - this.corsConfig = corsConfig; - this.threadContext = threadContext; - this.nettyRequest = request.getRequest(); - this.resetCookies = settings.isResetCookies(); + NioHttpChannel(SocketChannel socketChannel) throws IOException { + super(socketChannel); } - @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = ByteBufUtils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - ArrayList toClose = new ArrayList<>(3); - - boolean success = false; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - BytesReference content = response.content(); - if (content instanceof Releasable) { - toClose.add((Releasable) content); - } - BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); - if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { - toClose.add((Releasable) bytesStreamOutput); - } - - if (isCloseConnection()) { - toClose.add(nioChannel::close); - } - - BiConsumer listener = (aVoid, ex) -> Releasables.close(toClose); - nioChannel.getContext().sendMessage(new NioHttpResponse(sequence, resp), listener); - success = true; - } finally { - if (success == false) { - Releasables.close(toClose); - } - } - } - - @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(bigArrays); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } - } - - private void addCookies(HttpResponse resp) { - if (resetCookies) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); - } - } - } - } - - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); - } - - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); - } - - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public void sendResponse(HttpResponse response, ActionListener listener) { + getContext().sendMessage(response, ActionListener.toBiConsumer(listener)); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java index 1eb63364f99..977092ddac0 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java @@ -68,7 +68,7 @@ public class NioHttpPipeliningHandler extends ChannelDuplexHandler { List> readyResponses = aggregator.write(response, listener); success = true; for (Tuple responseToWrite : readyResponses) { - ctx.write(responseToWrite.v1().getResponse(), responseToWrite.v2()); + ctx.write(responseToWrite.v1(), responseToWrite.v2()); } } catch (IllegalStateException e) { ctx.channel().close(); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java index 4dcd6ba19e0..08937593f3b 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java @@ -19,13 +19,20 @@ package org.elasticsearch.http.nio; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import java.util.AbstractMap; import java.util.Collection; @@ -35,25 +42,17 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -public class NioHttpRequest extends RestRequest { +public class NioHttpRequest implements HttpRequest { private final FullHttpRequest request; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); - this.request = request; - if (request.content().isReadable()) { - this.content = ByteBufUtils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - - } - - NioHttpRequest(NamedXContentRegistry xContentRegistry, Map params, String uri, FullHttpRequest request) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); + NioHttpRequest(FullHttpRequest request, int sequence) { this.request = request; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = ByteBufUtils.toBytesReference(request.content()); } else { @@ -62,38 +61,38 @@ public class NioHttpRequest extends RestRequest { } @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -104,20 +103,66 @@ public class NioHttpRequest extends RestRequest { return request.uri(); } - @Override - public boolean hasContent() { - return content.length() > 0; - } - @Override public BytesReference content() { return content; } - public FullHttpRequest getRequest() { + + @Override + public final Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } + } + + @Override + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new NioHttpRequest(requestWithoutHeader, sequence); + } + + @Override + public NioHttpResponse createResponse(RestStatus status, BytesReference content) { + return new NioHttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { return request; } + int sequence() { + return sequence; + } + /** * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications * and due to the underlying implementation, it performs case insensitive lookups of key to values. diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java index 4b634994b45..24de843dcc8 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java @@ -19,19 +19,100 @@ package org.elasticsearch.http.nio; -import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedMessage; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.rest.RestStatus; -public class NioHttpResponse extends HttpPipelinedMessage { +import java.util.Collections; +import java.util.EnumMap; +import java.util.Map; - private final FullHttpResponse response; +public class NioHttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage { - public NioHttpResponse(int sequence, FullHttpResponse response) { - super(sequence); - this.response = response; + private final int sequence; + private final NioHttpRequest request; + + NioHttpResponse(NioHttpRequest request, RestStatus status, BytesReference content) { + super(request.nettyRequest().protocolVersion(), getStatus(status), ByteBufUtils.toByteBuf(content)); + this.sequence = request.sequence(); + this.request = request; } - public FullHttpResponse getResponse() { - return response; + @Override + public void addHeader(String name, String value) { + headers().add(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers().contains(name); + } + + @Override + public int getSequence() { + return sequence; + } + + private static Map MAP; + + public NioHttpRequest getRequest() { + return request; + } + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 57aaebb16a1..5aac491a6ab 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -42,7 +42,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; -import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.HttpStats; import org.elasticsearch.http.nio.cors.NioCorsConfig; @@ -53,11 +52,11 @@ import org.elasticsearch.nio.EventHandler; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannel; import org.elasticsearch.nio.NioGroup; +import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.nio.NioSelector; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -104,12 +103,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), (s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope); - private final BigArrays bigArrays; - private final ThreadPool threadPool; - private final NamedXContentRegistry xContentRegistry; - - private final HttpHandlingSettings httpHandlingSettings; - private final boolean tcpNoDelay; private final boolean tcpKeepAlive; private final boolean reuseAddress; @@ -124,16 +117,12 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); - this.bigArrays = bigArrays; - this.threadPool = threadPool; - this.xContentRegistry = xContentRegistry; + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);; this.corsConfig = buildCorsConfig(settings); this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); @@ -148,10 +137,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents); } - BigArrays getBigArrays() { - return bigArrays; - } - public Logger getLogger() { return logger; } @@ -335,17 +320,17 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { socketChannels.add(socketChannel); } - private class HttpChannelFactory extends ChannelFactory { + private class HttpChannelFactory extends ChannelFactory { private HttpChannelFactory() { super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); } @Override - public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { - NioSocketChannel nioChannel = new NioSocketChannel(channel); + public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + NioHttpChannel nioChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this, - httpHandlingSettings, xContentRegistry, corsConfig, threadPool.getThreadContext()); + handlingSettings, corsConfig); Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline, InboundChannelBuffer.allocatingInstance()); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java index 63585107037..98ae2d523ca 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java @@ -22,6 +22,7 @@ package org.elasticsearch.http.nio.cors; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.nio.NioHttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public class NioCorsHandler extends ChannelDuplexHandler { ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass(); + NioHttpResponse response = (NioHttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise); + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 6ad53521ee1..5bda7e1b83d 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -23,29 +23,31 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; -import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder; +import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -55,6 +57,9 @@ import java.nio.ByteBuffer; import java.util.List; import java.util.function.BiConsumer; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; @@ -64,7 +69,12 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEAD import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -72,7 +82,7 @@ import static org.mockito.Mockito.verify; public class HttpReadWriteHandlerTests extends ESTestCase { private HttpReadWriteHandler handler; - private NioSocketChannel nioSocketChannel; + private NioHttpChannel nioHttpChannel; private NioHttpServerTransport transport; private final RequestEncoder requestEncoder = new RequestEncoder(); @@ -96,15 +106,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase { SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings), SETTING_PIPELINING_MAX_EVENTS.getDefault(settings), SETTING_CORS_ENABLED.getDefault(settings)); - ThreadContext threadContext = new ThreadContext(settings); - nioSocketChannel = mock(NioSocketChannel.class); - handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY, - NioCorsConfigBuilder.forAnyOrigin().build(), threadContext); + nioHttpChannel = mock(NioHttpChannel.class); + handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build()); } public void testSuccessfulDecodeHttpRequest() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); int slicePoint = randomInt(buf.writerIndex() - 1); @@ -113,22 +121,21 @@ public class HttpReadWriteHandlerTests extends ESTestCase { ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); handler.consumeReads(toChannelBuffer(slicedBuf2)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest(); - assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion()); - assertEquals(httpRequest.method(), nettyHttpRequest.method()); + HttpRequest nioHttpRequest = requestCaptor.getValue(); + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + assertEquals(RestRequest.Method.GET, nioHttpRequest.method()); } public void testDecodeHttpRequestError() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); buf.setByte(0, ' '); @@ -137,15 +144,15 @@ public class HttpReadWriteHandlerTests extends ESTestCase { handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); - verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture()); assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); } public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); HttpUtil.setContentLength(httpRequest, 1025); HttpUtil.setKeepAlive(httpRequest, false); @@ -153,60 +160,176 @@ public class HttpReadWriteHandlerTests extends ESTestCase { handler.consumeReads(toChannelBuffer(buf)); - verify(transport, times(0)).dispatchBadRequest(any(), any(), any()); - verify(transport, times(0)).dispatchRequest(any(), any()); + verify(transport, times(0)).incomingRequestError(any(), any(), any()); + verify(transport, times(0)).incomingRequest(any(), any()); List flushOperations = handler.pollFlushOperations(); assertFalse(flushOperations.isEmpty()); FlushOperation flushOperation = flushOperations.get(0); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); flushOperation.getListener().accept(null, null); // Since we have keep-alive set to false, we should close the channel after the response has been // flushed - verify(nioSocketChannel).close(); + verify(nioHttpChannel).close(); } @SuppressWarnings("unchecked") public void testEncodeHttpResponse() throws IOException { prepareHandlerForResponse(handler); - FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); - NioHttpResponse pipelinedResponse = new NioHttpResponse(0, fullHttpResponse); + DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); + httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0"); SocketChannelContext context = mock(SocketChannelContext.class); - HttpWriteOperation writeOperation = new HttpWriteOperation(context, pipelinedResponse, mock(BiConsumer.class)); + HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class)); List flushOperations = handler.writeToBytes(writeOperation); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); assertEquals(HttpResponseStatus.OK, response.status()); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); } - private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException { - HttpMethod method = HttpMethod.GET; - HttpVersion version = HttpVersion.HTTP_1_1; + public void testCorsEnabledWithoutAllowOrigins() throws IOException { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() throws IOException { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() throws IOException { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + FullHttpResponse response = executeCorsRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() throws IOException { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() throws IOException { + final String originValue = NioCorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException { + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings); + HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig); + prepareHandlerForResponse(handler); + DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + NioHttpRequest nioHttpRequest = new NioHttpRequest(httpRequest, 0); + BytesArray content = new BytesArray("content"); + HttpResponse response = nioHttpRequest.createResponse(RestStatus.OK, content); + response.addHeader("Content-Length", Integer.toString(content.length())); + + SocketChannelContext context = mock(SocketChannelContext.class); + List flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {})); + + FlushOperation flushOperation = flushOperations.get(0); + return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + } + + + + private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException { + HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD; + HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1; String uri = "http://localhost:9090/" + randomAlphaOfLength(8); - HttpRequest request = new DefaultFullHttpRequest(version, method, uri); + io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri); ByteBuf buf = requestEncoder.encode(request); handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class); + verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest requestParsed = nioHttpRequest.getRequest(); - assertNotNull(requestParsed); - assertEquals(requestParsed.method(), method); - assertEquals(requestParsed.protocolVersion(), version); - assertEquals(requestParsed.uri(), uri); - return requestParsed; + NioHttpRequest nioHttpRequest = requestCaptor.getValue(); + assertNotNull(nioHttpRequest); + assertEquals(method.name(), nioHttpRequest.method().name()); + if (version == HttpVersion.HTTP_1_1) { + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + } else { + assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion()); + } + assertEquals(nioHttpRequest.uri(), uri); + return nioHttpRequest; } private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { @@ -226,11 +349,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase { return buffer; } + private static final int MAX = 16 * 1024 * 1024; + private static class RequestEncoder { - private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder()); + private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder(), new HttpObjectAggregator(MAX)); - private ByteBuf encode(HttpRequest httpRequest) { + private ByteBuf encode(io.netty.handler.codec.http.HttpRequest httpRequest) { requestEncoder.writeOutbound(httpRequest); return requestEncoder.readOutbound(); } @@ -238,9 +363,9 @@ public class HttpReadWriteHandlerTests extends ESTestCase { private static class ResponseDecoder { - private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder()); + private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(MAX)); - private HttpResponse decode(ByteBuf response) { + private FullHttpResponse decode(ByteBuf response) { responseDecoder.writeInbound(response); return responseDecoder.readInbound(); } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java deleted file mode 100644 index 5fa0a7ae0a6..00000000000 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.nio; - -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.charset.StandardCharsets; -import java.util.function.BiConsumer; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class NioHttpChannelTests extends ESTestCase { - - private ThreadPool threadPool; - private MockBigArrays bigArrays; - private NioSocketChannel nioChannel; - private SocketChannelContext channelContext; - - @Before - public void setup() throws Exception { - nioChannel = mock(NioSocketChannel.class); - channelContext = mock(SocketChannelContext.class); - when(nioChannel.getContext()).thenReturn(channelContext); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(ByteBufUtils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = NioCorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - // send a response - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, corsConfig, - threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext).sendMessage(responseCaptor.capture(), any()); - Object nioResponse = responseCaptor.getValue(); - HttpResponse response = ((NioHttpResponse) nioResponse).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - - @SuppressWarnings("unchecked") - public void testReleaseInListener() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final NioHttpRequest request = new NioHttpRequest(registry, httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - if (randomBoolean()) { - listener.accept(null, null); - } else { - listener.accept(null, new ClosedChannelException()); - } - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - - - @SuppressWarnings("unchecked") - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - listener.accept(null, null); - if (close) { - verify(nioChannel, times(1)).close(); - } else { - verify(nioChannel, times(0)).close(); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, httpHandlingSettings, corsConfig, - threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext, atLeastOnce()).sendMessage(responseCaptor.capture(), any()); - return ((NioHttpResponse) responseCaptor.getValue()).getResponse(); - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = ByteBufUtils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } -} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java index 94d7db171a5..5f2784a3567 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -190,11 +189,11 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - NioHttpResponse resp = new NioHttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -231,10 +230,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -244,9 +243,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -258,7 +258,7 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new NioHttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java index c43fc7d0723..48a5bf617a4 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java @@ -280,40 +280,6 @@ public class NioHttpServerTransportTests extends ESTestCase { assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (NioHttpServerTransport transport = - new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - // public void testReadTimeout() throws Exception { // final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { // diff --git a/qa/build.gradle b/qa/build.gradle index 709c309359e..0336b947d06 100644 --- a/qa/build.gradle +++ b/qa/build.gradle @@ -5,6 +5,20 @@ subprojects { Project subproj -> subproj.tasks.withType(RestIntegTestTask) { subproj.extensions.configure("${it.name}Cluster") { cluster -> cluster.distribution = System.getProperty('tests.distribution', 'oss-zip') + if (cluster.distribution == 'zip') { + /* + * Add Elastic's repositories so we can resolve older versions of the + * default distribution. Those aren't in maven central. + */ + repositories { + maven { + url "https://artifacts.elastic.co/maven" + } + maven { + url "https://snapshots.elastic.co/maven" + } + } + } } } } diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 1757548c28b..4432d864fd3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -50,7 +50,6 @@ import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.threadpool.ThreadPool; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -365,28 +364,11 @@ public class MasterService extends AbstractLifecycleComponent { } public Discovery.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) { - ArrayList ackListeners = new ArrayList<>(); - - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - nonFailedTasks.stream().filter(task -> task.listener instanceof AckedClusterStateTaskListener).forEach(task -> { - final AckedClusterStateTaskListener ackedListener = (AckedClusterStateTaskListener) task.listener; - if (ackedListener.ackTimeout() == null || ackedListener.ackTimeout().millis() == 0) { - ackedListener.onAckTimeout(); - } else { - try { - ackListeners.add(new AckCountDownListener(ackedListener, newClusterState.version(), newClusterState.nodes(), - threadPool)); - } catch (EsRejectedExecutionException ex) { - if (logger.isDebugEnabled()) { - logger.debug("Couldn't schedule timeout thread - node might be shutting down", ex); - } - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - ackedListener.onAckTimeout(); - } - } - }); - - return new DelegatingAckListener(ackListeners); + return new DelegatingAckListener(nonFailedTasks.stream() + .filter(task -> task.listener instanceof AckedClusterStateTaskListener) + .map(task -> new AckCountDownListener((AckedClusterStateTaskListener) task.listener, newClusterState.version(), + newClusterState.nodes(), threadPool)) + .collect(Collectors.toList())); } public boolean clusterStateUnchanged() { @@ -549,6 +531,13 @@ public class MasterService extends AbstractLifecycleComponent { this.listeners = listeners; } + @Override + public void onCommit(TimeValue commitTime) { + for (Discovery.AckListener listener : listeners) { + listener.onCommit(commitTime); + } + } + @Override public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { for (Discovery.AckListener listener : listeners) { @@ -564,14 +553,16 @@ public class MasterService extends AbstractLifecycleComponent { private final AckedClusterStateTaskListener ackedTaskListener; private final CountDown countDown; private final DiscoveryNode masterNode; + private final ThreadPool threadPool; private final long clusterStateVersion; - private final Future ackTimeoutCallback; + private volatile Future ackTimeoutCallback; private Exception lastFailure; AckCountDownListener(AckedClusterStateTaskListener ackedTaskListener, long clusterStateVersion, DiscoveryNodes nodes, ThreadPool threadPool) { this.ackedTaskListener = ackedTaskListener; this.clusterStateVersion = clusterStateVersion; + this.threadPool = threadPool; this.masterNode = nodes.getMasterNode(); int countDown = 0; for (DiscoveryNode node : nodes) { @@ -581,8 +572,27 @@ public class MasterService extends AbstractLifecycleComponent { } } logger.trace("expecting {} acknowledgements for cluster_state update (version: {})", countDown, clusterStateVersion); - this.countDown = new CountDown(countDown); - this.ackTimeoutCallback = threadPool.schedule(ackedTaskListener.ackTimeout(), ThreadPool.Names.GENERIC, () -> onTimeout()); + this.countDown = new CountDown(countDown + 1); // we also wait for onCommit to be called + } + + @Override + public void onCommit(TimeValue commitTime) { + TimeValue ackTimeout = ackedTaskListener.ackTimeout(); + if (ackTimeout == null) { + ackTimeout = TimeValue.ZERO; + } + final TimeValue timeLeft = TimeValue.timeValueNanos(Math.max(0, ackTimeout.nanos() - commitTime.nanos())); + if (timeLeft.nanos() == 0L) { + onTimeout(); + } else if (countDown.countDown()) { + finish(); + } else { + this.ackTimeoutCallback = threadPool.schedule(timeLeft, ThreadPool.Names.GENERIC, this::onTimeout); + // re-check if onNodeAck has not completed while we were scheduling the timeout + if (countDown.isCountedDown()) { + FutureUtils.cancel(ackTimeoutCallback); + } + } } @Override @@ -599,12 +609,16 @@ public class MasterService extends AbstractLifecycleComponent { } if (countDown.countDown()) { - logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); - FutureUtils.cancel(ackTimeoutCallback); - ackedTaskListener.onAllNodesAcked(lastFailure); + finish(); } } + private void finish() { + logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); + FutureUtils.cancel(ackTimeoutCallback); + ackedTaskListener.onAllNodesAcked(lastFailure); + } + public void onTimeout() { if (countDown.fastForward()) { logger.trace("timeout waiting for acknowledgement for cluster_state update (version: {})", clusterStateVersion); diff --git a/server/src/main/java/org/elasticsearch/discovery/Discovery.java b/server/src/main/java/org/elasticsearch/discovery/Discovery.java index 9c708760324..b58f61bac89 100644 --- a/server/src/main/java/org/elasticsearch/discovery/Discovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/Discovery.java @@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.TimeValue; import java.io.IOException; @@ -48,6 +49,19 @@ public interface Discovery extends LifecycleComponent { void publish(ClusterChangedEvent clusterChangedEvent, AckListener ackListener); interface AckListener { + /** + * Should be called when the discovery layer has committed the clusters state (i.e. even if this publication fails, + * it is guaranteed to appear in future publications). + * @param commitTime the time it took to commit the cluster state + */ + void onCommit(TimeValue commitTime); + + /** + * Should be called whenever the discovery layer receives confirmation from a node that it has successfully applied + * the cluster state. In case of failures, an exception should be provided as parameter. + * @param node the node + * @param e the optional exception + */ void onNodeAck(DiscoveryNode node, @Nullable Exception e); } diff --git a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java index cd775e29f5a..d7c37febb5d 100644 --- a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java @@ -30,6 +30,7 @@ import org.elasticsearch.cluster.service.ClusterApplier.ClusterApplyListener; import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoveryStats; import org.elasticsearch.transport.TransportService; @@ -61,6 +62,7 @@ public class SingleNodeDiscovery extends AbstractLifecycleComponent implements D public synchronized void publish(final ClusterChangedEvent event, final AckListener ackListener) { clusterState = event.state(); + ackListener.onCommit(TimeValue.ZERO); CountDownLatch latch = new CountDownLatch(1); ClusterApplyListener listener = new ClusterApplyListener() { diff --git a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java index cd87a415263..5398b2a057a 100644 --- a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java +++ b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java @@ -158,7 +158,8 @@ public class PublishClusterStateAction extends AbstractComponent { } try { - innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, sendFullVersion, serializedStates, serializedDiffs); + innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, ackListener, sendFullVersion, serializedStates, + serializedDiffs); } catch (Discovery.FailedToCommitClusterStateException t) { throw t; } catch (Exception e) { @@ -173,8 +174,9 @@ public class PublishClusterStateAction extends AbstractComponent { } private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final Set nodesToPublishTo, - final SendingController sendingController, final boolean sendFullVersion, - final Map serializedStates, final Map serializedDiffs) { + final SendingController sendingController, final Discovery.AckListener ackListener, + final boolean sendFullVersion, final Map serializedStates, + final Map serializedDiffs) { final ClusterState clusterState = clusterChangedEvent.state(); final ClusterState previousState = clusterChangedEvent.previousState(); @@ -195,8 +197,12 @@ public class PublishClusterStateAction extends AbstractComponent { sendingController.waitForCommit(discoverySettings.getCommitTimeout()); + final long commitTime = System.nanoTime() - publishingStartInNanos; + + ackListener.onCommit(TimeValue.timeValueNanos(commitTime)); + try { - long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - (System.nanoTime() - publishingStartInNanos)); + long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - commitTime); final BlockingClusterStatePublishResponseHandler publishResponseHandler = sendingController.getPublishResponseHandler(); sendingController.setPublishingTimedOut(!publishResponseHandler.awaitAllNodes(TimeValue.timeValueNanos(timeLeftInNanos))); if (sendingController.getPublishingTimedOut()) { diff --git a/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java b/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java index 0821b176e75..e048512e638 100644 --- a/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java +++ b/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java @@ -29,6 +29,7 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.OutputStreamIndexOutput; import org.apache.lucene.store.SimpleFSDirectory; +import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.bytes.BytesArray; @@ -76,6 +77,7 @@ public abstract class MetaDataStateFormat { private final String prefix; private final Pattern stateFilePattern; + private static final Logger logger = Loggers.getLogger(MetaDataStateFormat.class); /** * Creates a new {@link MetaDataStateFormat} instance @@ -134,6 +136,7 @@ public abstract class MetaDataStateFormat { IOUtils.fsync(tmpStatePath, false); // fsync the state file Files.move(tmpStatePath, finalStatePath, StandardCopyOption.ATOMIC_MOVE); IOUtils.fsync(stateLocation, true); + logger.trace("written state to {}", finalStatePath); for (int i = 1; i < locations.length; i++) { stateLocation = locations[i].resolve(STATE_DIR_NAME); Files.createDirectories(stateLocation); @@ -145,12 +148,15 @@ public abstract class MetaDataStateFormat { // we are on the same FileSystem / Partition here we can do an atomic move Files.move(tmpPath, finalPath, StandardCopyOption.ATOMIC_MOVE); IOUtils.fsync(stateLocation, true); + logger.trace("copied state to {}", finalPath); } finally { Files.deleteIfExists(tmpPath); + logger.trace("cleaned up {}", tmpPath); } } } finally { Files.deleteIfExists(tmpStatePath); + logger.trace("cleaned up {}", tmpStatePath); } cleanupOldFiles(prefix, fileName, locations); } @@ -211,20 +217,19 @@ public abstract class MetaDataStateFormat { } private void cleanupOldFiles(final String prefix, final String currentStateFile, Path[] locations) throws IOException { - final DirectoryStream.Filter filter = new DirectoryStream.Filter() { - @Override - public boolean accept(Path entry) throws IOException { - final String entryFileName = entry.getFileName().toString(); - return Files.isRegularFile(entry) - && entryFileName.startsWith(prefix) // only state files - && currentStateFile.equals(entryFileName) == false; // keep the current state file around - } + final DirectoryStream.Filter filter = entry -> { + final String entryFileName = entry.getFileName().toString(); + return Files.isRegularFile(entry) + && entryFileName.startsWith(prefix) // only state files + && currentStateFile.equals(entryFileName) == false; // keep the current state file around }; // now clean up the old files for (Path dataLocation : locations) { + logger.trace("cleanupOldFiles: cleaning up {}", dataLocation); try (DirectoryStream stream = Files.newDirectoryStream(dataLocation.resolve(STATE_DIR_NAME), filter)) { for (Path stateFile : stream) { Files.deleteIfExists(stateFile); + logger.trace("cleanupOldFiles: cleaned up {}", stateFile); } } } diff --git a/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java b/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java index 00b981175f2..fd1698bb006 100644 --- a/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java +++ b/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java @@ -123,6 +123,7 @@ public class MetaStateService extends AbstractComponent { try { IndexMetaData.FORMAT.write(indexMetaData, nodeEnv.indexPaths(indexMetaData.getIndex())); + logger.trace("[{}] state written", index); } catch (Exception ex) { logger.warn(() -> new ParameterizedMessage("[{}]: failed to write index state", index), ex); throw new IOException("failed to write state for [" + index + "]", ex); @@ -136,6 +137,7 @@ public class MetaStateService extends AbstractComponent { logger.trace("[_global] writing state, reason [{}]", reason); try { MetaData.FORMAT.write(metaData, nodeEnv.nodeDataPaths()); + logger.trace("[_global] state written"); } catch (Exception ex) { logger.warn("[_global]: failed to write global state", ex); throw new IOException("failed to write global state", ex); diff --git a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java index c75754bde58..4fad4159f55 100644 --- a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java @@ -21,6 +21,7 @@ package org.elasticsearch.http; import com.carrotsearch.hppc.IntHashSet; import com.carrotsearch.hppc.IntSet; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.network.NetworkService; @@ -29,7 +30,9 @@ import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.PortsRange; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.threadpool.ThreadPool; @@ -48,11 +51,14 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; -public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport { +public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport { + public final HttpHandlingSettings handlingSettings; protected final NetworkService networkService; + protected final BigArrays bigArrays; protected final ThreadPool threadPool; protected final Dispatcher dispatcher; + private final NamedXContentRegistry xContentRegistry; protected final String[] bindHosts; protected final String[] publishHosts; @@ -61,11 +67,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected volatile BoundTransportAddress boundAddress; - protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) { + protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { super(settings); this.networkService = networkService; + this.bigArrays = bigArrays; this.threadPool = threadPool; + this.xContentRegistry = xContentRegistry; this.dispatcher = dispatcher; + this.handlingSettings = HttpHandlingSettings.fromSettings(settings); // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); @@ -156,17 +166,94 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo return publishPort; } - public void dispatchRequest(final RestRequest request, final RestChannel channel) { + /** + * This method handles an incoming http request. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + */ + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { + handleIncomingRequest(httpRequest, httpChannel, null); + } + + /** + * This method handles an incoming http request that has encountered an error. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + * @param exception that was encountered + */ + public void incomingRequestError(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + handleIncomingRequest(httpRequest, httpChannel, exception); + } + + // Visible for testing + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchRequest(request, channel, threadContext); + if (badRequestCause != null) { + dispatcher.dispatchBadRequest(restRequest, channel, threadContext, badRequestCause); + } else { + dispatcher.dispatchRequest(restRequest, channel, threadContext); + } } } - public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchBadRequest(request, channel, threadContext, cause); + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + Exception badRequestCause = exception; + + /* + * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there + * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we + * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, + * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the + * underlying exception that caused us to treat the request as bad. + */ + final RestRequest restRequest; + { + RestRequest innerRestRequest; + try { + innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel); + } catch (final RestRequest.ContentTypeHeaderException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause); + } catch (final RestRequest.BadParameterException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + } + restRequest = innerRestRequest; + } + + /* + * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid + * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an + * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these + * parameter values. + */ + final RestChannel channel; + { + RestChannel innerChannel; + ThreadContext threadContext = threadPool.getThreadContext(); + try { + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext); + } catch (final IllegalArgumentException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext); + } + channel = innerChannel; + } + + dispatchRequest(restRequest, channel, badRequestCause); + } + + private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) { + HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); + try { + return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel); + } catch (final RestRequest.BadParameterException e) { + badRequestCause.addSuppressed(e); + return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } } } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java new file mode 100644 index 00000000000..f5924bb239e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -0,0 +1,172 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * The default rest channel for incoming requests. This class implements the basic logic for sending a rest + * response. It will set necessary headers nad ensure that bytes are released after the response is sent. + */ +public class DefaultRestChannel extends AbstractRestChannel implements RestChannel { + + static final String CLOSE = "close"; + static final String CONNECTION = "connection"; + static final String KEEP_ALIVE = "keep-alive"; + static final String CONTENT_TYPE = "content-type"; + static final String CONTENT_LENGTH = "content-length"; + static final String SET_COOKIE = "set-cookie"; + static final String X_OPAQUE_ID = "X-Opaque-Id"; + + private final HttpRequest httpRequest; + private final BigArrays bigArrays; + private final HttpHandlingSettings settings; + private final ThreadContext threadContext; + private final HttpChannel httpChannel; + + DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays, + HttpHandlingSettings settings, ThreadContext threadContext) { + super(request, settings.getDetailedErrorsEnabled()); + this.httpChannel = httpChannel; + this.httpRequest = httpRequest; + this.bigArrays = bigArrays; + this.settings = settings; + this.threadContext = threadContext; + } + + @Override + protected BytesStreamOutput newBytesOutput() { + return new ReleasableBytesStreamOutput(bigArrays); + } + + @Override + public void sendResponse(RestResponse restResponse) { + HttpResponse httpResponse; + if (RestRequest.Method.HEAD == request.method()) { + httpResponse = httpRequest.createResponse(restResponse.status(), BytesArray.EMPTY); + } else { + httpResponse = httpRequest.createResponse(restResponse.status(), restResponse.content()); + } + + // TODO: Ideally we should move the setting of Cors headers into :server + // NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); + + String opaque = request.header(X_OPAQUE_ID); + if (opaque != null) { + setHeaderField(httpResponse, X_OPAQUE_ID, opaque); + } + + // Add all custom headers + addCustomHeaders(httpResponse, restResponse.getHeaders()); + addCustomHeaders(httpResponse, threadContext.getResponseHeaders()); + + ArrayList toClose = new ArrayList<>(3); + + boolean success = false; + try { + // If our response doesn't specify a content-type header, set one + setHeaderField(httpResponse, CONTENT_TYPE, restResponse.contentType(), false); + // If our response has no content-length, calculate and set one + setHeaderField(httpResponse, CONTENT_LENGTH, String.valueOf(restResponse.content().length()), false); + + addCookies(httpResponse); + + BytesReference content = restResponse.content(); + if (content instanceof Releasable) { + toClose.add((Releasable) content); + } + BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); + if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { + toClose.add((Releasable) bytesStreamOutput); + } + + if (isCloseConnection()) { + toClose.add(httpChannel); + } + + ActionListener listener = ActionListener.wrap(() -> Releasables.close(toClose)); + httpChannel.sendResponse(httpResponse, listener); + success = true; + } finally { + if (success == false) { + Releasables.close(toClose); + } + } + + } + + private void setHeaderField(HttpResponse response, String headerField, String value) { + setHeaderField(response, headerField, value, true); + } + + private void setHeaderField(HttpResponse response, String headerField, String value, boolean override) { + if (override || !response.containsHeader(headerField)) { + response.addHeader(headerField, value); + } + } + + private void addCustomHeaders(HttpResponse response, Map> customHeaders) { + if (customHeaders != null) { + for (Map.Entry> headerEntry : customHeaders.entrySet()) { + for (String headerValue : headerEntry.getValue()) { + setHeaderField(response, headerEntry.getKey(), headerValue); + } + } + } + } + + private void addCookies(HttpResponse response) { + if (settings.isResetCookies()) { + List cookies = request.getHttpRequest().strictCookies(); + if (cookies.isEmpty() == false) { + for (String cookie : cookies) { + response.addHeader(SET_COOKIE, cookie); + } + } + } + } + + // Determine if the request connection should be closed on completion. + private boolean isCloseConnection() { + final boolean http10 = isHttp10(); + return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); + } + + // Determine if the request protocol version is HTTP 1.0 + private boolean isHttp10() { + return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + } +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpChannel.java b/server/src/main/java/org/elasticsearch/http/HttpChannel.java new file mode 100644 index 00000000000..baea3e0c3b3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpChannel.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.lease.Releasable; + +import java.net.InetSocketAddress; + +public interface HttpChannel extends Releasable { + + /** + * Sends a http response to the channel. The listener will be executed once the send process has been + * completed. + * + * @param response to send to channel + * @param listener to execute upon send completion + */ + void sendResponse(HttpResponse response, ActionListener listener); + + /** + * Returns the local address for this channel. + * + * @return the local address of this channel. + */ + InetSocketAddress getLocalAddress(); + + /** + * Returns the remote address for this channel. Can be null if channel does not have a remote address. + * + * @return the remote address of this channel. + */ + InetSocketAddress getRemoteAddress(); + + /** + * Closes the channel. This might be an asynchronous process. There is no guarantee that the channel + * will be closed when this method returns. + */ + void close(); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java index 7db8666e73a..ae1520cba60 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java @@ -18,20 +18,17 @@ */ package org.elasticsearch.http; -public class HttpPipelinedMessage implements Comparable { +public interface HttpPipelinedMessage extends Comparable { - private final int sequence; - - public HttpPipelinedMessage(int sequence) { - this.sequence = sequence; - } - - public int getSequence() { - return sequence; - } + /** + * Get the sequence number for this message. + * + * @return the sequence number + */ + int getSequence(); @Override - public int compareTo(HttpPipelinedMessage o) { - return Integer.compare(sequence, o.sequence); + default int compareTo(HttpPipelinedMessage o) { + return Integer.compare(getSequence(), o.getSequence()); } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java index df8bd7ee1eb..db3a2bae167 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java @@ -18,15 +18,21 @@ */ package org.elasticsearch.http; -public class HttpPipelinedRequest extends HttpPipelinedMessage { +public class HttpPipelinedRequest implements HttpPipelinedMessage { private final R request; + private final int sequence; HttpPipelinedRequest(int sequence, R request) { - super(sequence); + this.sequence = sequence; this.request = request; } + @Override + public int getSequence() { + return sequence; + } + public R getRequest() { return request; } diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java new file mode 100644 index 00000000000..496fec23312 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; + +import java.util.List; +import java.util.Map; + +/** + * A basic http request abstraction. Http modules needs to implement this interface to integrate with the + * server package's rest handling. + */ +public interface HttpRequest { + + enum HttpVersion { + HTTP_1_0, + HTTP_1_1 + } + + RestRequest.Method method(); + + /** + * The uri of the rest request, with the query string. + */ + String uri(); + + BytesReference content(); + + /** + * Get all of the headers and values associated with the headers. Modifications of this map are not supported. + */ + Map> getHeaders(); + + List strictCookies(); + + HttpVersion protocolVersion(); + + HttpRequest removeHeader(String header); + + /** + * Create an http response from this request and the supplied status and content. + */ + HttpResponse createResponse(RestStatus status, BytesReference content); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpResponse.java b/server/src/main/java/org/elasticsearch/http/HttpResponse.java new file mode 100644 index 00000000000..2d363f663c3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpResponse.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +/** + * A basic http response abstraction. Http modules must implement this interface as the server package rest + * handling needs to set http headers for a response. + */ +public interface HttpResponse { + + void addHeader(String name, String value); + + boolean containsHeader(String name); + +} diff --git a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java index d376b65ef2d..4e3d652ec5d 100644 --- a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java +++ b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java @@ -40,7 +40,7 @@ public abstract class AbstractRestChannel implements RestChannel { private static final Predicate EXCLUDE_FILTER = INCLUDE_FILTER.negate(); protected final RestRequest request; - protected final boolean detailedErrorsEnabled; + private final boolean detailedErrorsEnabled; private final String format; private final String filterPath; private final boolean pretty; diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index aae63f041fa..82fcf7178d1 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -272,8 +272,9 @@ public class RestController extends AbstractComponent implements HttpServerTrans */ private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) { if (restRequest.getXContentType() == null) { - if (restHandler.supportsContentStream() && restRequest.header("Content-Type") != null) { - final String lowercaseMediaType = restRequest.header("Content-Type").toLowerCase(Locale.ROOT); + String contentTypeHeader = restRequest.header("Content-Type"); + if (restHandler.supportsContentStream() && contentTypeHeader != null) { + final String lowercaseMediaType = contentTypeHeader.toLowerCase(Locale.ROOT); // we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/ if (lowercaseMediaType.equals("application/x-ndjson")) { restRequest.setXContentType(XContentType.JSON); diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index 65b4f9d1d36..813d6feb551 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -35,10 +35,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; import java.io.IOException; import java.io.InputStream; -import java.net.SocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -51,7 +52,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; -public abstract class RestRequest implements ToXContent.Params { +public class RestRequest implements ToXContent.Params { // tchar pattern as defined by RFC7230 section 3.2.6 private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+"); @@ -62,18 +63,47 @@ public abstract class RestRequest implements ToXContent.Params { private final String rawPath; private final Set consumedParams = new HashSet<>(); private final SetOnce xContentType = new SetOnce<>(); + private final HttpRequest httpRequest; + private final HttpChannel httpChannel; + + protected RestRequest(NamedXContentRegistry xContentRegistry, Map params, String path, + Map> headers, HttpRequest httpRequest, HttpChannel httpChannel) { + final XContentType xContentType; + try { + xContentType = parseContentType(headers.get("Content-Type")); + } catch (final IllegalArgumentException e) { + throw new ContentTypeHeaderException(e); + } + if (xContentType != null) { + this.xContentType.set(xContentType); + } + this.xContentRegistry = xContentRegistry; + this.httpRequest = httpRequest; + this.httpChannel = httpChannel; + this.params = params; + this.rawPath = path; + this.headers = Collections.unmodifiableMap(headers); + } + + protected RestRequest(RestRequest restRequest) { + this(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + } /** - * Creates a new REST request. + * Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be + * decoded * * @param xContentRegistry the content registry - * @param uri the raw URI that will be parsed into the path and the parameters - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws BadParameterException if the parameters can not be decoded * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map> headers) { - this(xContentRegistry, params(uri), path(uri), headers); + public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, HttpChannel httpChannel) { + Map params = params(httpRequest.uri()); + String path = path(httpRequest.uri()); + return new RestRequest(xContentRegistry, params, path, httpRequest.getHeaders(), httpRequest, httpChannel); } private static Map params(final String uri) { @@ -99,46 +129,34 @@ public abstract class RestRequest implements ToXContent.Params { } /** - * Creates a new REST request. In contrast to - * {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw - * a {@link BadParameterException}. + * Creates a new REST request. The path is not decoded so this constructor will not throw a + * {@link BadParameterException}. * * @param xContentRegistry the content registry - * @param params the request parameters - * @param path the raw path (which is not parsed) - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String path, - final Map> headers) { - final XContentType xContentType; - try { - xContentType = parseContentType(headers.get("Content-Type")); - } catch (final IllegalArgumentException e) { - throw new ContentTypeHeaderException(e); - } - if (xContentType != null) { - this.xContentType.set(xContentType); - } - this.xContentRegistry = xContentRegistry; - this.params = params; - this.rawPath = path; - this.headers = Collections.unmodifiableMap(headers); + public static RestRequest requestWithoutParameters(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, + HttpChannel httpChannel) { + Map params = Collections.emptyMap(); + return new RestRequest(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } public enum Method { GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT } - public abstract Method method(); + public Method method() { + return httpRequest.method(); + } /** * The uri of the rest request, with the query string. */ - public abstract String uri(); + public String uri() { + return httpRequest.uri(); + } /** * The non decoded, raw path provided. @@ -154,9 +172,13 @@ public abstract class RestRequest implements ToXContent.Params { return RestUtils.decodeComponent(rawPath()); } - public abstract boolean hasContent(); + public boolean hasContent() { + return content().length() > 0; + } - public abstract BytesReference content(); + public BytesReference content() { + return httpRequest.content(); + } /** * @return content of the request body or throw an exception if the body or content type is missing @@ -216,14 +238,12 @@ public abstract class RestRequest implements ToXContent.Params { this.xContentType.set(xContentType); } - @Nullable - public SocketAddress getRemoteAddress() { - return null; + public HttpChannel getHttpChannel() { + return httpChannel; } - @Nullable - public SocketAddress getLocalAddress() { - return null; + public HttpRequest getHttpRequest() { + return httpRequest; } public final boolean hasParam(String key) { diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 7e031f8d004..d0d6fa752d6 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -20,10 +20,10 @@ package org.elasticsearch.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,8 +31,7 @@ import java.util.Set; public abstract class RestResponse { - protected Map> customHeaders; - + private Map> customHeaders; /** * The response content type. @@ -81,10 +80,13 @@ public abstract class RestResponse { } /** - * Returns custom headers that have been added, or null if none have been set. + * Returns custom headers that have been added. This method should not be used to mutate headers. */ - @Nullable public Map> getHeaders() { - return customHeaders; + if (customHeaders == null) { + return Collections.emptyMap(); + } else { + return customHeaders; + } } } diff --git a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java index 1b747f22687..f75363c7ab5 100644 --- a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java @@ -22,6 +22,7 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; @@ -39,6 +40,7 @@ import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.BaseFuture; +import org.elasticsearch.discovery.Discovery; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -65,6 +67,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; @@ -680,6 +683,132 @@ public class MasterServiceTests extends ESTestCase { mockAppender.assertAllExpectationsMatched(); } + public void testAcking() throws InterruptedException { + final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node3 = new DiscoveryNode("node3", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + TimedMasterService timedMasterService = new TimedMasterService(Settings.builder().put("cluster.name", + MasterServiceTests.class.getSimpleName()).build(), threadPool); + ClusterState initialClusterState = ClusterState.builder(new ClusterName(MasterServiceTests.class.getSimpleName())) + .nodes(DiscoveryNodes.builder() + .add(node1) + .add(node2) + .add(node3) + .localNodeId(node1.getId()) + .masterNodeId(node1.getId())) + .blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK).build(); + final AtomicReference> publisherRef = new AtomicReference<>(); + timedMasterService.setClusterStatePublisher((cce, l) -> publisherRef.get().accept(cce, l)); + timedMasterService.setClusterStateSupplier(() -> initialClusterState); + timedMasterService.start(); + + + // check that we don't time out before even committing the cluster state + { + final CountDownLatch latch = new CountDownLatch(1); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + throw new Discovery.FailedToCommitClusterStateException("mock exception"); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return TimeValue.ZERO; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + fail(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + latch.countDown(); + } + + @Override + public void onAckTimeout() { + fail(); + } + }); + + latch.await(); + } + + // check that we timeout if commit took too long + { + final CountDownLatch latch = new CountDownLatch(2); + + final TimeValue ackTimeout = TimeValue.timeValueMillis(randomInt(100)); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + ackListener.onCommit(TimeValue.timeValueMillis(ackTimeout.millis() + randomInt(100))); + ackListener.onNodeAck(node1, null); + ackListener.onNodeAck(node2, null); + ackListener.onNodeAck(node3, null); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return ackTimeout; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + latch.countDown(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + fail(); + } + + @Override + public void onAckTimeout() { + latch.countDown(); + } + }); + + latch.await(); + } + + timedMasterService.close(); + } + static class TimedMasterService extends MasterService { public volatile Long currentTimeOverride = null; diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java index c8e85382994..ac1719269e7 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.node.Node; @@ -815,9 +816,16 @@ public class PublishClusterStateActionTests extends ESTestCase { public static class AssertingAckListener implements Discovery.AckListener { private final List> errors = new CopyOnWriteArrayList<>(); private final CountDownLatch countDown; + private final CountDownLatch commitCountDown; public AssertingAckListener(int nodeCount) { countDown = new CountDownLatch(nodeCount); + commitCountDown = new CountDownLatch(1); + } + + @Override + public void onCommit(TimeValue commitTime) { + commitCountDown.countDown(); } @Override @@ -830,6 +838,7 @@ public class PublishClusterStateActionTests extends ESTestCase { public void await(long timeout, TimeUnit unit) throws InterruptedException { assertThat(awaitErrors(timeout, unit), emptyIterable()); + assertTrue(commitCountDown.await(timeout, unit)); } public List> awaitErrors(long timeout, TimeUnit unit) throws InterruptedException { diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index ee74d98002f..a7629e5f48b 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -19,13 +19,27 @@ package org.elasticsearch.http; +import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkUtils; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import java.io.IOException; +import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import static java.net.InetAddress.getByName; @@ -36,6 +50,27 @@ import static org.hamcrest.Matchers.equalTo; public class AbstractHttpServerTransportTests extends ESTestCase { + private NetworkService networkService; + private ThreadPool threadPool; + private MockBigArrays bigArrays; + + @Before + public void setup() throws Exception { + networkService = new NetworkService(Collections.emptyList()); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() throws Exception { + if (threadPool != null) { + threadPool.shutdownNow(); + } + threadPool = null; + networkService = null; + bigArrays = null; + } + public void testHttpPublishPort() throws Exception { int boundPort = randomIntBetween(9000, 9100); int otherBoundPort = randomIntBetween(9200, 9300); @@ -71,6 +106,64 @@ public class AbstractHttpServerTransportTests extends ESTestCase { } } + public void testDispatchDoesNotModifyThreadContext() { + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + + @Override + public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("bar", "baz"); + } + + @Override + public void dispatchBadRequest(final RestRequest request, + final RestChannel channel, + final ThreadContext threadContext, + final Throwable cause) { + threadContext.putHeader("foo_bad", "bar"); + threadContext.putTransient("bar_bad", "baz"); + } + + }; + + try (AbstractHttpServerTransport transport = + new AbstractHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher) { + @Override + protected TransportAddress bindAddress(InetAddress hostAddress) { + return null; + } + + @Override + protected void doStart() { + + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } + + @Override + public HttpStats stats() { + return null; + } + }) { + + transport.dispatchRequest(null, null, null); + assertNull(threadPool.getThreadContext().getHeader("foo")); + assertNull(threadPool.getThreadContext().getTransient("bar")); + + transport.dispatchRequest(null, null, new Exception()); + assertNull(threadPool.getThreadContext().getHeader("foo_bad")); + assertNull(threadPool.getThreadContext().getTransient("bar_bad")); + } + } + private TransportAddress address(String host, int port) throws UnknownHostException { return new TransportAddress(getByName(host), port); } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java new file mode 100644 index 00000000000..bc499ed8a42 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -0,0 +1,444 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class DefaultRestChannelTests extends ESTestCase { + + private ThreadPool threadPool; + private MockBigArrays bigArrays; + private HttpChannel httpChannel; + + @Before + public void setup() { + httpChannel = mock(HttpChannel.class); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() { + if (threadPool != null) { + threadPool.shutdownNow(); + } + } + + public void testResponse() { + final TestResponse response = executeRequest(Settings.EMPTY, "request-host"); + assertThat(response.content(), equalTo(new TestRestResponse().content())); + } + + // TODO: Enable these Cors tests when the Cors logic lives in :server + +// public void testCorsEnabledWithoutAllowOrigins() { +// // Set up a HTTP transport with only the CORS enabled setting +// Settings settings = Settings.builder() +// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, "remote-host", "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); +// } +// +// public void testCorsEnabledWithAllowOrigins() { +// final String originValue = "remote-host"; +// // create a http transport with CORS enabled and allow origin configured +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testCorsAllowOriginWithSameHost() { +// String originValue = "remote-host"; +// String host = "remote-host"; +// // create a http transport with CORS enabled +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, host); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = "http://" + originValue; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue + ":5555"; +// host = host + ":5555"; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue.replace("http", "https"); +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testThatStringLiteralWorksOnMatch() { +// final String originValue = "remote-host"; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") +// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); +// } +// +// public void testThatAnyOriginWorks() { +// final String originValue = NioCorsHandler.ANY_ORIGIN; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); +// } + + public void testHeadersSet() { + Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + TestRestResponse resp = new TestRestResponse(); + final String customHeader = "custom-header"; + final String customHeaderValue = "xyz"; + resp.addHeader(customHeader, customHeaderValue); + channel.sendResponse(resp); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse httpResponse = responseCaptor.getValue(); + Map> headers = httpResponse.headers; + assertNull(headers.get("non-existent-header")); + assertEquals(customHeaderValue, headers.get(customHeader).get(0)); + assertEquals("abc", headers.get(DefaultRestChannel.X_OPAQUE_ID).get(0)); + assertEquals(Integer.toString(resp.content().length()), headers.get(DefaultRestChannel.CONTENT_LENGTH).get(0)); + assertEquals(resp.contentType(), headers.get(DefaultRestChannel.CONTENT_TYPE).get(0)); + } + + public void testCookiesSet() { + Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse nioResponse = responseCaptor.getValue(); + Map> headers = nioResponse.headers; + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie")); + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2")); + } + + @SuppressWarnings("unchecked") + public void testReleaseInListener() throws IOException { + final Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, + JsonXContent.contentBuilder().startObject().endObject()); + assertThat(response.content(), not(instanceOf(Releasable.class))); + + // ensure we have reserved bytes + if (randomBoolean()) { + BytesStreamOutput out = channel.bytesOutput(); + assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); + } else { + try (XContentBuilder builder = channel.newBuilder()) { + // do something builder + builder.startObject().endObject(); + } + } + + channel.sendResponse(response); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released + } + + @SuppressWarnings("unchecked") + public void testConnectionClose() throws Exception { + final Settings settings = Settings.builder().build(); + final HttpRequest httpRequest; + final boolean close = randomBoolean(); + if (randomBoolean()) { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + if (close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE)); + } + } else { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/"); + if (!close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE)); + } + } + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + if (close) { + verify(httpChannel, times(1)).close(); + } else { + verify(httpChannel, times(0)).close(); + } + } + + private TestResponse executeRequest(final Settings settings, final String host) { + return executeRequest(settings, null, host); + } + + private TestResponse executeRequest(final Settings settings, final String originValue, final String host) { + HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + // TODO: These exist for the Cors tests +// if (originValue != null) { +// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); +// } +// httpRequest.headers().add(HttpHeaderNames.HOST, host); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // get the response + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any()); + return responseCaptor.getValue(); + } + + private static class TestRequest implements HttpRequest { + + private final HttpVersion version; + private final RestRequest.Method method; + private final String uri; + private HashMap> headers = new HashMap<>(); + + private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { + + this.version = version; + this.method = method; + this.uri = uri; + } + + @Override + public RestRequest.Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return BytesArray.EMPTY; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Arrays.asList("cookie", "cookie2"); + } + + @Override + public HttpVersion protocolVersion() { + return version; + } + + @Override + public HttpRequest removeHeader(String header) { + throw new UnsupportedOperationException("Do not support removing header on test request."); + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + return new TestResponse(status, content); + } + } + + private static class TestResponse implements HttpResponse { + + private final RestStatus status; + private final BytesReference content; + private final Map> headers = new HashMap<>(); + + TestResponse(RestStatus status, BytesReference content) { + this.status = status; + this.content = content; + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return status; + } + + @Override + public void addHeader(String name, String value) { + if (headers.containsKey(name) == false) { + ArrayList values = new ArrayList<>(); + values.add(value); + headers.put(name, values); + } else { + headers.get(name).add(value); + } + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + } + + private static class TestRestResponse extends RestResponse { + + private final BytesReference content; + + TestRestResponse() { + content = new BytesArray("content".getBytes(StandardCharsets.UTF_8)); + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return RestStatus.OK; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java index a0e6f702030..a80c3b1bd42 100644 --- a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java +++ b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.transport.TransportAddress; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; @@ -165,28 +164,7 @@ public class BytesRestResponseTests extends ESTestCase { public void testResponseWhenPathContainsEncodingError() throws IOException { final String path = "%a"; - final RestRequest request = - new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) { - @Override - public Method method() { - return null; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return false; - } - - @Override - public BytesReference content() { - return null; - } - }; + final RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath(path).build(); final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath())); final RestChannel channel = new DetailedExceptionRestChannel(request); // if we try to decode the path, this will throw an IllegalArgumentException again diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index f36638a4390..a090cc40b68 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -110,21 +110,21 @@ public class RestControllerTests extends ESTestCase { RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build(); final RestController spyRestController = spy(restController); when(spyRestController.getAllHandlers(fakeRequest)) - .thenReturn(new Iterator() { - @Override - public boolean hasNext() { - return false; - } + .thenReturn(new Iterator() { + @Override + public boolean hasNext() { + return false; + } - @Override - public MethodHandlers next() { - return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { - assertEquals("true", threadContext.getHeader("header.1")); - assertEquals("true", threadContext.getHeader("header.2")); - assertNull(threadContext.getHeader("header.3")); - }, RestRequest.Method.GET); - } - }); + @Override + public MethodHandlers next() { + return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { + assertEquals("true", threadContext.getHeader("header.1")); + assertEquals("true", threadContext.getHeader("header.2")); + assertNull(threadContext.getHeader("header.3")); + }, RestRequest.Method.GET); + } + }); AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST); restController.dispatchRequest(fakeRequest, channel, threadContext); // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the @@ -136,7 +136,7 @@ public class RestControllerTests extends ESTestCase { public void testCanTripCircuitBreaker() throws Exception { RestController controller = new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService, - usageService); + usageService); // trip circuit breaker by default controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true)); controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false)); @@ -209,7 +209,7 @@ public class RestControllerTests extends ESTestCase { return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true); }; final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null, - circuitBreakerService, usageService); + circuitBreakerService, usageService); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(), null, null, Optional.of(handler)); assertTrue(wrapperCalled.get()); @@ -240,7 +240,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestAddsAndFreesBytesOnSuccess() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -252,7 +252,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestAddsAndFreesBytesOnError() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -265,7 +265,7 @@ public class RestControllerTests extends ESTestCase { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); // we will produce an error in the rest handler and one more when sending the error response - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -277,7 +277,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestLimitsBytes() { int contentLength = BREAKER_LIMIT.bytesAsInt() + 1; String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -288,11 +288,11 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequiresContentTypeForRequestsWithContent() { String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt()); - TestRestRequest request = new TestRestRequest("/", content, null); + RestRequest request = testRestRequest("/", content, null); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE); restController = new RestController( Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(), - Collections.emptySet(), null, null, circuitBreakerService, usageService); + Collections.emptySet(), null, null, circuitBreakerService, usageService); restController.registerHandler(RestRequest.Method.GET, "/", (r, c, client) -> c.sendResponse( new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))); @@ -412,8 +412,8 @@ public class RestControllerTests extends ESTestCase { public void testNonStreamingXContentCausesErrorResponse() throws IOException { FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), - XContentType.YAML).withPath("/foo").build(); + .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), + XContentType.YAML).withPath("/foo").build(); AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.NOT_ACCEPTABLE); restController.registerHandler(RestRequest.Method.GET, "/foo", new RestHandler() { @Override @@ -457,10 +457,10 @@ public class RestControllerTests extends ESTestCase { final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); final AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST); restController.dispatchBadRequest( - fakeRestRequest, - channel, - new ThreadContext(Settings.EMPTY), - randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); + fakeRestRequest, + channel, + new ThreadContext(Settings.EMPTY), + randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); assertTrue(channel.getSendResponseCalled()); assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad request")); } @@ -495,7 +495,7 @@ public class RestControllerTests extends ESTestCase { @Override public BoundTransportAddress boundAddress() { TransportAddress transportAddress = buildNewFakeTransportAddress(); - return new BoundTransportAddress(new TransportAddress[] {transportAddress} ,transportAddress); + return new BoundTransportAddress(new TransportAddress[]{transportAddress}, transportAddress); } @Override @@ -547,35 +547,11 @@ public class RestControllerTests extends ESTestCase { } } - private static final class TestRestRequest extends RestRequest { - - private final BytesReference content; - - private TestRestRequest(String path, String content, XContentType xContentType) { - super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, xContentType == null ? - Collections.emptyMap() : Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType()))); - this.content = new BytesArray(content); - } - - @Override - public Method method() { - return Method.GET; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return true; - } - - @Override - public BytesReference content() { - return content; - } - + private static RestRequest testRestRequest(String path, String content, XContentType xContentType) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withPath(path); + builder.withContent(new BytesArray(content), xContentType); + return builder.build(); } } + diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index 1b4bbff7322..3ad9c61de3c 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; import java.io.IOException; import java.util.ArrayList; @@ -44,66 +45,66 @@ import static org.hamcrest.Matchers.instanceOf; public class RestRequestTests extends ESTestCase { public void testContentParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentParser()); + contentRestRequest("", emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", singletonMap("source", "{}")).contentParser()); + contentRestRequest("", singletonMap("source", "{}")).contentParser()); assertEquals("request body is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentParser().map()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap(), emptyMap()).contentParser()); + contentRestRequest("", emptyMap(), emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); } public void testApplyContentParser() throws IOException { - new ContentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); - new ContentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); AtomicReference source = new AtomicReference<>(); - new ContentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); + contentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); assertEquals(emptyMap(), source.get()); } public void testContentOrSourceParam() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParam()); + contentRestRequest("", emptyMap()).contentOrSourceParam()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("{\"foo\": \"stuff\"}"), - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .contentOrSourceParam().v2()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "stuff2").immutableMap()).contentOrSourceParam()); assertEquals("source and source_content_type parameters are required", e.getMessage()); } public void testHasContentOrSourceParam() throws IOException { - assertEquals(false, new ContentRestRequest("", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); + assertEquals(false, contentRestRequest("", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); } public void testContentOrSourceParamParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParamParser()); + contentRestRequest("", emptyMap()).contentOrSourceParamParser()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("", MapBuilder.newMapBuilder() + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map()); } public void testWithContentOrSourceParamParserOrNull() throws IOException { - new ContentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); - new ContentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> + contentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); + contentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); + contentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") + contentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") .put("source", "{}").immutableMap()) .withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); @@ -113,18 +114,18 @@ public class RestRequestTests extends ESTestCase { for (XContentType xContentType : XContentType.values()) { Map> map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaType())); - ContentRestRequest restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + RestRequest restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters())); - restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); } } public void testPlainTextSupport() { - ContentRestRequest restRequest = new ContentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), + RestRequest restRequest = contentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8")))); assertNull(restRequest.getXContentType()); @@ -136,7 +137,7 @@ public class RestRequestTests extends ESTestCase { RestRequest.ContentTypeHeaderException.class, () -> { final Map> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type)); - new ContentRestRequest("", Collections.emptyMap(), headers); + contentRestRequest("", Collections.emptyMap(), headers); }); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf(IllegalArgumentException.class)); @@ -144,7 +145,7 @@ public class RestRequestTests extends ESTestCase { } public void testNoContentTypeHeader() { - ContentRestRequest contentRestRequest = new ContentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); + RestRequest contentRestRequest = contentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); assertNull(contentRestRequest.getXContentType()); } @@ -152,7 +153,7 @@ public class RestRequestTests extends ESTestCase { List headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10))); final RestRequest.ContentTypeHeaderException e = expectThrows( RestRequest.ContentTypeHeaderException.class, - () -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); + () -> contentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf((IllegalArgumentException.class))); assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided")); @@ -160,52 +161,64 @@ public class RestRequestTests extends ESTestCase { public void testRequiredContent() { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).requiredContent()); + contentRestRequest("", emptyMap()).requiredContent()); assertEquals("request body is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).requiredContent()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).requiredContent()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .requiredContent()); assertEquals("request body is required", e.getMessage()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("test", null, Collections.emptyMap()).requiredContent()); + contentRestRequest("test", null, Collections.emptyMap()).requiredContent()); assertEquals("unknown content type", e.getMessage()); } + private static RestRequest contentRestRequest(String content, Map params) { + Map> headers = new HashMap<>(); + headers.put("Content-Type", Collections.singletonList("application/json")); + return contentRestRequest(content, params, headers); + } + + private static RestRequest contentRestRequest(String content, Map params, Map> headers) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(headers); + builder.withContent(new BytesArray(content), null); + builder.withParams(params); + return new ContentRestRequest(builder.build()); + } + private static final class ContentRestRequest extends RestRequest { - private final BytesArray content; - ContentRestRequest(String content, Map params) { - this(content, params, Collections.singletonMap("Content-Type", Collections.singletonList("application/json"))); - } + private final RestRequest restRequest; - ContentRestRequest(String content, Map params, Map> headers) { - super(NamedXContentRegistry.EMPTY, params, "not used by this test", headers); - this.content = new BytesArray(content); - } - - @Override - public boolean hasContent() { - return Strings.hasLength(content); - } - - @Override - public BytesReference content() { - return content; - } - - @Override - public String uri() { - throw new UnsupportedOperationException("Not used by this test"); + private ContentRestRequest(RestRequest restRequest) { + super(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + this.restRequest = restRequest; } @Override public Method method() { - throw new UnsupportedOperationException("Not used by this test"); + return restRequest.method(); + } + + @Override + public String uri() { + return restRequest.uri(); + } + + @Override + public boolean hasContent() { + return Strings.hasLength(content()); + } + + @Override + public BytesReference content() { + return restRequest.content(); } } } diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java index a497e509c15..8cfec0a07f9 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import java.util.Collections; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import static org.elasticsearch.transport.RemoteClusterConnectionTests.startTransport; @@ -69,7 +70,6 @@ public class RemoteClusterClientTests extends ESTestCase { } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/29547") public void testEnsureWeReconnect() throws Exception { Settings remoteSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "foo_bar_cluster").build(); try (MockTransportService remoteTransport = startTransport("remote_node", Collections.emptyList(), Version.CURRENT, threadPool, @@ -79,17 +79,35 @@ public class RemoteClusterClientTests extends ESTestCase { .put(RemoteClusterService.ENABLE_REMOTE_CLUSTERS.getKey(), true) .put("search.remote.test.seeds", remoteNode.getAddress().getAddress() + ":" + remoteNode.getAddress().getPort()).build(); try (MockTransportService service = MockTransportService.createNewService(localSettings, Version.CURRENT, threadPool, null)) { + Semaphore semaphore = new Semaphore(1); service.start(); + service.addConnectionListener(new TransportConnectionListener() { + @Override + public void onNodeDisconnected(DiscoveryNode node) { + if (remoteNode.equals(node)) { + semaphore.release(); + } + } + }); + // this test is not perfect since we might reconnect concurrently but it will fail most of the time if we don't have + // the right calls in place in the RemoteAwareClient service.acceptIncomingRequests(); - service.disconnectFromNode(remoteNode); - RemoteClusterService remoteClusterService = service.getRemoteClusterService(); - assertBusy(() -> assertFalse(remoteClusterService.isRemoteNodeConnected("test", remoteNode))); - Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test"); - ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get(); - assertNotNull(clusterStateResponse); - assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value()); + for (int i = 0; i < 10; i++) { + semaphore.acquire(); + try { + service.disconnectFromNode(remoteNode); + semaphore.acquire(); + RemoteClusterService remoteClusterService = service.getRemoteClusterService(); + Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test"); + ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get(); + assertNotNull(clusterStateResponse); + assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value()); + assertTrue(remoteClusterService.isRemoteNodeConnected("test", remoteNode)); + } finally { + semaphore.release(); + } + } } } } - } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index d0403736400..4d4743156c7 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -19,12 +19,18 @@ package org.elasticsearch.test.rest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; -import java.net.SocketAddress; +import java.net.InetSocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -32,45 +38,115 @@ import java.util.Map; public class FakeRestRequest extends RestRequest { - private final BytesReference content; - private final Method method; - private final SocketAddress remoteAddress; - public FakeRestRequest() { - this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null); + this(NamedXContentRegistry.EMPTY, new FakeHttpRequest(Method.GET, "", BytesArray.EMPTY, new HashMap<>()), new HashMap<>(), + new FakeHttpChannel(null)); } - private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map> headers, - Map params, BytesReference content, Method method, String path, SocketAddress remoteAddress) { - super(xContentRegistry, params, path, headers); - this.content = content; - this.method = method; - this.remoteAddress = remoteAddress; - } - - @Override - public Method method() { - return method; - } - - @Override - public String uri() { - return rawPath(); + private FakeRestRequest(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, Map params, + HttpChannel httpChannel) { + super(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } @Override public boolean hasContent() { - return content != null; + return content() != null; } - @Override - public BytesReference content() { - return content; + private static class FakeHttpRequest implements HttpRequest { + + private final Method method; + private final String uri; + private final BytesReference content; + private final Map> headers; + + private FakeHttpRequest(Method method, String uri, BytesReference content, Map> headers) { + this.method = method; + this.uri = uri; + this.content = content; + this.headers = headers; + } + + @Override + public Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return content; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + return HttpVersion.HTTP_1_1; + } + + @Override + public HttpRequest removeHeader(String header) { + headers.remove(header); + return this; + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + Map headers = new HashMap<>(); + return new HttpResponse() { + @Override + public void addHeader(String name, String value) { + headers.put(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + }; + } } - @Override - public SocketAddress getRemoteAddress() { - return remoteAddress; + private static class FakeHttpChannel implements HttpChannel { + + private final InetSocketAddress remoteAddress; + + private FakeHttpChannel(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + + @Override + public void sendResponse(HttpResponse response, ActionListener listener) { + + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return remoteAddress; + } + + @Override + public void close() { + + } } public static class Builder { @@ -86,7 +162,7 @@ public class FakeRestRequest extends RestRequest { private Method method = Method.GET; - private SocketAddress address = null; + private InetSocketAddress address = null; public Builder(NamedXContentRegistry xContentRegistry) { this.xContentRegistry = xContentRegistry; @@ -120,15 +196,14 @@ public class FakeRestRequest extends RestRequest { return this; } - public Builder withRemoteAddress(SocketAddress address) { + public Builder withRemoteAddress(InetSocketAddress address) { this.address = address; return this; } public FakeRestRequest build() { - return new FakeRestRequest(xContentRegistry, headers, params, content, method, path, address); + FakeHttpRequest fakeHttpRequest = new FakeHttpRequest(method, path, content, headers); + return new FakeRestRequest(xContentRegistry, fakeHttpRequest, params, new FakeHttpChannel(address)); } - } - } diff --git a/x-pack/build.gradle b/x-pack/build.gradle index 91652b9e150..6a064ff5b7c 100644 --- a/x-pack/build.gradle +++ b/x-pack/build.gradle @@ -5,14 +5,6 @@ import org.elasticsearch.gradle.precommit.LicenseHeadersTask Project xpackRootProject = project -apply plugin: 'nebula.info-scm' -final String licenseCommit -if (version.endsWith('-SNAPSHOT')) { - licenseCommit = xpackRootProject.scminfo.change ?: "master" // leniency for non git builds -} else { - licenseCommit = "v${version}" -} - subprojects { group = 'org.elasticsearch.plugin' ext.xpackRootProject = xpackRootProject @@ -21,7 +13,7 @@ subprojects { ext.xpackModule = { String moduleName -> xpackProject("plugin:${moduleName}").path } ext.licenseName = 'Elastic License' - ext.licenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt" + ext.licenseUrl = ext.elasticLicenseUrl project.ext.licenseFile = rootProject.file('licenses/ELASTIC-LICENSE.txt') project.ext.noticeFile = xpackRootProject.file('NOTICE.txt') diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index de4d3ada51a..ac423c42811 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -43,9 +43,7 @@ subprojects { final FileCollection classDirectories = project.files(files).filter { it.exists() } doFirst { - String cp = project.configurations.featureAwarePlugin.asPath - cp = cp.replaceAll(":[^:]*/asm-debug-all-5.1.jar:", ":") - args('-cp', cp, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck') + args('-cp', project.configurations.featureAwarePlugin.asPath, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck') classDirectories.each { args it.getAbsolutePath() } } doLast { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java index de6ee3d509c..991f421265e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.job.config; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -30,6 +31,7 @@ public class MlFilter implements ToXContentObject, Writeable { public static final ParseField TYPE = new ParseField("type"); public static final ParseField ID = new ParseField("filter_id"); + public static final ParseField DESCRIPTION = new ParseField("description"); public static final ParseField ITEMS = new ParseField("items"); // For QueryPage @@ -43,27 +45,38 @@ public class MlFilter implements ToXContentObject, Writeable { parser.declareString((builder, s) -> {}, TYPE); parser.declareString(Builder::setId, ID); + parser.declareStringOrNull(Builder::setDescription, DESCRIPTION); parser.declareStringArray(Builder::setItems, ITEMS); return parser; } private final String id; + private final String description; private final List items; - public MlFilter(String id, List items) { + public MlFilter(String id, String description, List items) { this.id = Objects.requireNonNull(id, ID.getPreferredName() + " must not be null"); + this.description = description; this.items = Objects.requireNonNull(items, ITEMS.getPreferredName() + " must not be null"); } public MlFilter(StreamInput in) throws IOException { id = in.readString(); + if (in.getVersion().onOrAfter(Version.V_6_4_0)) { + description = in.readOptionalString(); + } else { + description = null; + } items = Arrays.asList(in.readStringArray()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); + if (out.getVersion().onOrAfter(Version.V_6_4_0)) { + out.writeOptionalString(description); + } out.writeStringArray(items.toArray(new String[items.size()])); } @@ -71,6 +84,9 @@ public class MlFilter implements ToXContentObject, Writeable { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ID.getPreferredName(), id); + if (description != null) { + builder.field(DESCRIPTION.getPreferredName(), description); + } builder.field(ITEMS.getPreferredName(), items); if (params.paramAsBoolean(MlMetaIndex.INCLUDE_TYPE_KEY, false)) { builder.field(TYPE.getPreferredName(), FILTER_TYPE); @@ -83,6 +99,10 @@ public class MlFilter implements ToXContentObject, Writeable { return id; } + public String getDescription() { + return description; + } + public List getItems() { return new ArrayList<>(items); } @@ -98,12 +118,12 @@ public class MlFilter implements ToXContentObject, Writeable { } MlFilter other = (MlFilter) obj; - return id.equals(other.id) && items.equals(other.items); + return id.equals(other.id) && Objects.equals(description, other.description) && items.equals(other.items); } @Override public int hashCode() { - return Objects.hash(id, items); + return Objects.hash(id, description, items); } public String documentId() { @@ -114,30 +134,45 @@ public class MlFilter implements ToXContentObject, Writeable { return DOCUMENT_ID_PREFIX + filterId; } + public static Builder builder(String filterId) { + return new Builder().setId(filterId); + } + public static class Builder { private String id; + private String description; private List items = Collections.emptyList(); + private Builder() {} + public Builder setId(String id) { this.id = id; return this; } - private Builder() {} - @Nullable public String getId() { return id; } + public Builder setDescription(String description) { + this.description = description; + return this; + } + public Builder setItems(List items) { this.items = items; return this; } + public Builder setItems(String... items) { + this.items = Arrays.asList(items); + return this; + } + public MlFilter build() { - return new MlFilter(id, items); + return new MlFilter(id, description, items); } } } \ No newline at end of file diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java index aec5b3a04d2..71424ec507f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.security.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -17,7 +16,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.rest.RestRequest; import java.io.IOException; -import java.net.SocketAddress; import java.util.Map; import java.util.Set; @@ -33,37 +31,15 @@ public interface RestRequestFilter { default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException { Set fields = getFilteredFields(); if (restRequest.hasContent() && fields.isEmpty() == false) { - return new RestRequest(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders()) { + return new RestRequest(restRequest) { private BytesReference filteredBytes = null; - @Override - public Method method() { - return restRequest.method(); - } - - @Override - public String uri() { - return restRequest.uri(); - } - @Override public boolean hasContent() { return true; } - @Nullable - @Override - public SocketAddress getRemoteAddress() { - return restRequest.getRemoteAddress(); - } - - @Nullable - @Override - public SocketAddress getLocalAddress() { - return restRequest.getLocalAddress(); - } - @Override public BytesReference content() { if (filteredBytes == null) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java index c8465c87587..7bda0f6e7de 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.test.AbstractStreamableTestCase; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction.Response; import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.job.config.MlFilter; +import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests; import java.util.Collections; @@ -17,9 +18,7 @@ public class GetFiltersActionResponseTests extends AbstractStreamableTestCase result; - - MlFilter doc = new MlFilter( - randomAlphaOfLengthBetween(1, 20), Collections.singletonList(randomAlphaOfLengthBetween(1, 20))); + MlFilter doc = MlFilterTests.createRandom(); result = new QueryPage<>(Collections.singletonList(doc), 1, MlFilter.RESULTS_FIELD); return new Response(result); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java index 21845922470..dfc3f5f37f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java @@ -8,10 +8,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractStreamableXContentTestCase; import org.elasticsearch.xpack.core.ml.action.PutFilterAction.Request; -import org.elasticsearch.xpack.core.ml.job.config.MlFilter; - -import java.util.ArrayList; -import java.util.List; +import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests; public class PutFilterActionRequestTests extends AbstractStreamableXContentTestCase { @@ -19,13 +16,7 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC @Override protected Request createTestInstance() { - int size = randomInt(10); - List items = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - items.add(randomAlphaOfLengthBetween(1, 20)); - } - MlFilter filter = new MlFilter(filterId, items); - return new PutFilterAction.Request(filter); + return new PutFilterAction.Request(MlFilterTests.createRandom(filterId)); } @Override @@ -42,5 +33,4 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC protected Request doParseInstance(XContentParser parser) { return PutFilterAction.Request.parseRequest(filterId, parser); } - } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java index 1b61e3ec9a4..78d87b82839 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java @@ -26,12 +26,25 @@ public class MlFilterTests extends AbstractSerializingTestCase { @Override protected MlFilter createTestInstance() { + return createRandom(); + } + + public static MlFilter createRandom() { + return createRandom(randomAlphaOfLengthBetween(1, 20)); + } + + public static MlFilter createRandom(String filterId) { + String description = null; + if (randomBoolean()) { + description = randomAlphaOfLength(20); + } + int size = randomInt(10); List items = new ArrayList<>(size); for (int i = 0; i < size; i++) { items.add(randomAlphaOfLengthBetween(1, 20)); } - return new MlFilter(randomAlphaOfLengthBetween(1, 20), items); + return new MlFilter(filterId, description, items); } @Override @@ -45,13 +58,13 @@ public class MlFilterTests extends AbstractSerializingTestCase { } public void testNullId() { - NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, Collections.emptyList())); + NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, "", Collections.emptyList())); assertEquals(MlFilter.ID.getPreferredName() + " must not be null", ex.getMessage()); } public void testNullItems() { NullPointerException ex = - expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), null)); + expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), "", null)); assertEquals(MlFilter.ITEMS.getPreferredName() + " must not be null", ex.getMessage()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java index 7e0dc453f07..856b930ac49 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java @@ -385,8 +385,8 @@ public class JobProviderIT extends MlSingleNodeTestCase { indexScheduledEvents(events); List filters = new ArrayList<>(); - filters.add(new MlFilter("fruit", Arrays.asList("apple", "pear"))); - filters.add(new MlFilter("tea", Arrays.asList("green", "builders"))); + filters.add(MlFilter.builder("fruit").setItems("apple", "pear").build()); + filters.add(MlFilter.builder("tea").setItems("green", "builders").build()); indexFilters(filters); DataCounts earliestCounts = DataCountsTests.createTestInstance(jobId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java index 454f941d6c8..42b0a56f49a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java @@ -210,7 +210,7 @@ public class JobManagerTests extends ESTestCase { JobManager jobManager = createJobManager(); - MlFilter filter = new MlFilter("foo_filter", Arrays.asList("a", "b")); + MlFilter filter = MlFilter.builder("foo_filter").setItems("a", "b").build(); jobManager.updateProcessOnFilterChanged(filter); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java index 8c32a5bb40d..3d08f5a1c25 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java @@ -207,8 +207,8 @@ public class ControlMsgToProcessWriterTests extends ESTestCase { public void testWriteUpdateFiltersMessage() throws IOException { ControlMsgToProcessWriter writer = new ControlMsgToProcessWriter(lengthEncodedWriter, 2); - MlFilter filter1 = new MlFilter("filter_1", Arrays.asList("a")); - MlFilter filter2 = new MlFilter("filter_2", Arrays.asList("b", "c")); + MlFilter filter1 = MlFilter.builder("filter_1").setItems("a").build(); + MlFilter filter2 = MlFilter.builder("filter_2").setItems("b", "c").build(); writer.writeUpdateFiltersMessage(Arrays.asList(filter1, filter2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java index bf08d09bf09..d26dbb203c8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java @@ -220,8 +220,8 @@ public class FieldConfigWriterTests extends ESTestCase { AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(d)); analysisConfig = builder.build(); - filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); - filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); + filters.add(MlFilter.builder("filter_1").setItems("a", "b").build()); + filters.add(MlFilter.builder("filter_2").setItems("c", "d").build()); writer = mock(OutputStreamWriter.class); createFieldConfigWriter().write(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java index f22f7d85090..12ceb12f462 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.xpack.core.ml.job.config.MlFilter; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -28,8 +27,8 @@ public class MlFilterWriterTests extends ESTestCase { public void testWrite() throws IOException { List filters = new ArrayList<>(); - filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); - filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); + filters.add(MlFilter.builder("filter_1").setItems("a", "b").build()); + filters.add(MlFilter.builder("filter_2").setItems("c", "d").build()); StringBuilder buffer = new StringBuilder(); new MlFilterWriter(filters, buffer).write(); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java index 1976722d65f..1991c2685f2 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java @@ -69,7 +69,6 @@ import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -829,10 +828,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } @@ -854,10 +852,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java index 3b9a42179a5..5706f79011a 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.security.transport.filter.SecurityIpFilterRule; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; import java.util.EnumSet; @@ -544,13 +543,8 @@ public class LoggingAuditTrail extends AbstractComponent implements AuditTrail, } private static String hostAttributes(RestRequest request) { - String formattedAddress; - final SocketAddress socketAddress = request.getRemoteAddress(); - if (socketAddress instanceof InetSocketAddress) { - formattedAddress = NetworkAddress.format(((InetSocketAddress) socketAddress).getAddress()); - } else { - formattedAddress = socketAddress.toString(); - } + final InetSocketAddress socketAddress = request.getHttpChannel().getRemoteAddress(); + String formattedAddress = NetworkAddress.format(socketAddress.getAddress()); return "origin_address=[" + formattedAddress + "]"; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java index dcee6535cf3..ed50a5cfe84 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java @@ -20,7 +20,7 @@ public class RemoteHostHeader { * then be copied to the subsequent action requests. */ public static void process(RestRequest request, ThreadContext threadContext) { - threadContext.putTransient(KEY, request.getRemoteAddress()); + threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress()); } /** diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index 0f4da8b847c..9109bb37e8c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.rest; +import io.netty.channel.Channel; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -13,7 +14,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.logging.ESLoggerFactory; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.netty4.Netty4HttpRequest; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.netty4.Netty4HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -50,10 +52,11 @@ public class SecurityRestFilter implements RestHandler { if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) { // CORS - allow for preflight unauthenticated OPTIONS request if (extractClientCertificate) { - Netty4HttpRequest nettyHttpRequest = (Netty4HttpRequest) request; - SslHandler handler = nettyHttpRequest.getChannel().pipeline().get(SslHandler.class); + HttpChannel httpChannel = request.getHttpChannel(); + Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel(); + SslHandler handler = nettyChannel.pipeline().get(SslHandler.class); assert handler != null; - ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel()); + ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyChannel); } service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap( authentication -> { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java index 01916b91380..ac586c49457 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java @@ -104,7 +104,7 @@ public class SecurityNetty4HttpServerTransport extends Netty4HttpServerTransport private final class HttpSslChannelHandler extends HttpChannelHandler { HttpSslChannelHandler() { - super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext()); + super(SecurityNetty4HttpServerTransport.this, handlingSettings); } @Override diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java index 7878fdb9233..2e2a931f78f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.plugins.MetaDataUpgrader; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestRequest; @@ -914,7 +915,9 @@ public class IndexAuditTrailTests extends SecurityIntegTestCase { private RestRequest mockRestRequest() { RestRequest request = mock(RestRequest.class); - when(request.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); + HttpChannel httpChannel = mock(HttpChannel.class); + when(request.getHttpChannel()).thenReturn(httpChannel); + when(httpChannel.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); when(request.uri()).thenReturn("_uri"); return request; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java index 335673f1c0c..127784dcfc0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java @@ -88,6 +88,6 @@ public class RestRequestFilterTests extends ESTestCase { new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON) .withRemoteAddress(address).build(); RestRequest filtered = filter.getFilteredRequest(restRequest); - assertEquals(address, filtered.getRemoteAddress()); + assertEquals(address, filtered.getHttpChannel().getRemoteAddress()); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java index 2857aee9b61..5db634c8d7b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -67,6 +68,7 @@ public class SecurityRestFilterTests extends ESTestCase { public void testProcess() throws Exception { RestRequest request = mock(RestRequest.class); + when(request.getHttpChannel()).thenReturn(mock(HttpChannel.class)); Authentication authentication = mock(Authentication.class); doAnswer((i) -> { ActionListener callback = diff --git a/x-pack/plugin/sql/build.gradle b/x-pack/plugin/sql/build.gradle index 8b406235985..19dd1a08ec6 100644 --- a/x-pack/plugin/sql/build.gradle +++ b/x-pack/plugin/sql/build.gradle @@ -20,7 +20,10 @@ integTest.enabled = false dependencies { compileOnly "org.elasticsearch.plugin:x-pack-core:${version}" - compileOnly project(':modules:lang-painless') + compileOnly(project(':modules:lang-painless')) { + // exclude ASM to not affect featureAware task on Java 10+ + exclude group: "org.ow2.asm" + } compile project('sql-proto') compile "org.elasticsearch.plugin:aggs-matrix-stats-client:${version}" compile "org.antlr:antlr4-runtime:4.5.3" diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml index d3165260f4b..a1f7eee0dcc 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml @@ -32,6 +32,7 @@ setup: filter_id: filter-foo2 body: > { + "description": "This filter has a description", "items": ["123", "lmnop"] } @@ -76,6 +77,7 @@ setup: - match: filters.1: filter_id: "filter-foo2" + description: "This filter has a description" items: ["123", "lmnop"] - do: diff --git a/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java b/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java index aa53d6255cb..b99170546df 100644 --- a/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java +++ b/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java @@ -120,7 +120,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { } public void testScope() throws Exception { - MlFilter safeIps = new MlFilter("safe_ips", Arrays.asList("111.111.111.111", "222.222.222.222")); + MlFilter safeIps = MlFilter.builder("safe_ips").setItems("111.111.111.111", "222.222.222.222").build(); assertThat(putMlFilter(safeIps), is(true)); DetectionRule rule = new DetectionRule.Builder(RuleScope.builder().include("ip", "safe_ips")).build(); @@ -178,7 +178,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { assertThat(records.get(0).getOverFieldValue(), equalTo("333.333.333.333")); // Now let's update the filter - MlFilter updatedFilter = new MlFilter(safeIps.getId(), Collections.singletonList("333.333.333.333")); + MlFilter updatedFilter = MlFilter.builder(safeIps.getId()).setItems("333.333.333.333").build(); assertThat(putMlFilter(updatedFilter), is(true)); // Wait until the notification that the process was updated is indexed @@ -229,7 +229,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { public void testScopeAndCondition() throws IOException { // We have 2 IPs and they're both safe-listed. List ips = Arrays.asList("111.111.111.111", "222.222.222.222"); - MlFilter safeIps = new MlFilter("safe_ips", ips); + MlFilter safeIps = MlFilter.builder("safe_ips").setItems(ips).build(); assertThat(putMlFilter(safeIps), is(true)); // Ignore if ip in safe list AND actual < 10.