From 43a48829518a24a8a4ee5d04c06a82acdb6e949e Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 8 Sep 2020 08:36:18 -0600 Subject: [PATCH] Move CorsHandler to server (#62007) Currently we duplicate our specialized cors logic in all transport plugins. This is unnecessary as it could be implemented in a single place. This commit moves the logic to server. Additionally it fixes a but where we are incorrectly closing http channels on early Cors responses. --- .../netty4/Netty4HttpServerTransport.java | 6 +- .../http/netty4/cors/Netty4CorsHandler.java | 253 -------------- .../http/netty4/Netty4CorsTests.java | 149 -------- .../http/nio/HttpReadWriteHandler.java | 8 +- .../http/nio/NioHttpServerTransport.java | 2 +- .../http/nio/cors/NioCorsHandler.java | 254 -------------- .../http/nio/HttpReadWriteHandlerTests.java | 161 +-------- .../http/AbstractHttpServerTransport.java | 28 +- .../org/elasticsearch/http/CorsHandler.java | 187 +++++++++- .../http/DefaultRestChannel.java | 23 +- .../org/elasticsearch/http/HttpRequest.java | 17 + .../org/elasticsearch/http/HttpUtils.java | 39 +++ .../org/elasticsearch/rest/RestUtils.java | 2 +- .../elasticsearch/http/CorsHandlerTests.java | 245 +++++++++++++- .../http/DefaultRestChannelTests.java | 318 +++++------------- .../elasticsearch/http/TestHttpRequest.java | 103 ++++++ .../elasticsearch/http/TestHttpResponse.java | 68 ++++ .../nio/SecurityNioHttpServerTransport.java | 2 +- 18 files changed, 772 insertions(+), 1093 deletions(-) delete mode 100644 modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java delete mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java delete mode 100644 plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpUtils.java create mode 100644 server/src/test/java/org/elasticsearch/http/TestHttpRequest.java create mode 100644 server/src/test/java/org/elasticsearch/http/TestHttpResponse.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 5fae4cb4bea..61a3dfd5f2a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -59,10 +59,9 @@ import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpReadTimeoutException; import org.elasticsearch.http.HttpServerChannel; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.SharedGroupFactory; import org.elasticsearch.transport.NettyAllocator; +import org.elasticsearch.transport.SharedGroupFactory; import org.elasticsearch.transport.netty4.Netty4Utils; import java.net.InetSocketAddress; @@ -315,9 +314,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel())); } ch.pipeline().addLast("request_creator", requestCreator); - if (handlingSettings.isCorsEnabled()) { - ch.pipeline().addLast("cors", new Netty4CorsHandler(transport.corsConfig)); - } ch.pipeline().addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents)); ch.pipeline().addLast("handler", requestHandler); transport.serverAcceptedChannel(nettyHttpChannel); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java deleted file mode 100644 index a89f01da6df..00000000000 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java +++ /dev/null @@ -1,253 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.netty4.cors; - -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import org.elasticsearch.common.Strings; -import org.elasticsearch.http.CorsHandler; -import org.elasticsearch.http.netty4.Netty4HttpRequest; -import org.elasticsearch.http.netty4.Netty4HttpResponse; - -import java.util.Date; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -/** - * Handles Cross Origin Resource Sharing (CORS) requests. - *

- * This handler can be configured using a {@link CorsHandler.Config}, please - * refer to this class for details about the configuration options available. - * - */ -public class Netty4CorsHandler extends ChannelDuplexHandler { - - public static final String ANY_ORIGIN = "*"; - private static Pattern SCHEME_PATTERN = Pattern.compile("^https?://"); - - private final CorsHandler.Config config; - private Netty4HttpRequest request; - - /** - * Creates a new instance with the specified {@link CorsHandler.Config}. - */ - public Netty4CorsHandler(final CorsHandler.Config config) { - if (config == null) { - throw new NullPointerException(); - } - this.config = config; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - assert msg instanceof Netty4HttpRequest : "Invalid message type: " + msg.getClass(); - if (config.isCorsSupportEnabled()) { - request = (Netty4HttpRequest) msg; - if (isPreflightRequest(request.nettyRequest())) { - try { - handlePreflight(ctx, request.nettyRequest()); - return; - } finally { - releaseRequest(); - } - } - if (!validateOrigin()) { - try { - forbidden(ctx, request.nettyRequest()); - return; - } finally { - releaseRequest(); - } - } - } - ctx.fireChannelRead(msg); - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass(); - Netty4HttpResponse response = (Netty4HttpResponse) msg; - setCorsResponseHeaders(response.requestHeaders(), response, config); - ctx.write(response, promise); - } - - public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) { - if (!config.isCorsSupportEnabled()) { - return; - } - String originHeader = headers.get(HttpHeaderNames.ORIGIN); - if (!Strings.isNullOrEmpty(originHeader)) { - final String originHeaderVal; - if (config.isAnyOriginSupported()) { - originHeaderVal = ANY_ORIGIN; - } else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) { - originHeaderVal = originHeader; - } else { - originHeaderVal = null; - } - if (originHeaderVal != null) { - resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal); - } - } - if (config.isCredentialsAllowed()) { - resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - } - } - - private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { - final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true); - if (setOrigin(response)) { - setAllowMethods(response); - setAllowHeaders(response); - setAllowCredentials(response); - setMaxAge(response); - setPreflightHeaders(response); - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); - } else { - forbidden(ctx, request); - } - } - - private void releaseRequest() { - request.release(); - request = null; - } - - private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { - ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN)) - .addListener(ChannelFutureListener.CLOSE); - } - - private static boolean isSameOrigin(final String origin, final String host) { - if (Strings.isNullOrEmpty(host) == false) { - // strip protocol from origin - final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst(""); - if (host.equals(originDomain)) { - return true; - } - } - return false; - } - - /** - * This is a non CORS specification feature which enables the setting of preflight - * response headers that might be required by intermediaries. - * - * @param response the HttpResponse to which the preflight response headers should be added. - */ - private void setPreflightHeaders(final HttpResponse response) { - response.headers().add("date", new Date()); - response.headers().add("content-length", "0"); - } - - private boolean setOrigin(final HttpResponse response) { - final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN); - if (!Strings.isNullOrEmpty(origin)) { - if (config.isAnyOriginSupported()) { - if (config.isCredentialsAllowed()) { - echoRequestOrigin(response); - setVaryHeader(response); - } else { - setAnyOrigin(response); - } - return true; - } - if (config.isOriginAllowed(origin)) { - setOrigin(response, origin); - setVaryHeader(response); - return true; - } - } - return false; - } - - private boolean validateOrigin() { - if (config.isAnyOriginSupported()) { - return true; - } - - final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN); - if (Strings.isNullOrEmpty(origin)) { - // Not a CORS request so we cannot validate it. It may be a non CORS request. - return true; - } - - // if the origin is the same as the host of the request, then allow - if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) { - return true; - } - - return config.isOriginAllowed(origin); - } - - private void echoRequestOrigin(final HttpResponse response) { - setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN)); - } - - private static void setVaryHeader(final HttpResponse response) { - response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN); - } - - private static void setAnyOrigin(final HttpResponse response) { - setOrigin(response, ANY_ORIGIN); - } - - private static void setOrigin(final HttpResponse response, final String origin) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin); - } - - private void setAllowCredentials(final HttpResponse response) { - if (config.isCredentialsAllowed() - && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - } - } - - private static boolean isPreflightRequest(final HttpRequest request) { - final HttpHeaders headers = request.headers(); - return request.method().equals(HttpMethod.OPTIONS) && - headers.contains(HttpHeaderNames.ORIGIN) && - headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD); - } - - private void setAllowMethods(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream() - .map(m -> m.name().trim()) - .collect(Collectors.toList())); - } - - private void setAllowHeaders(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders()); - } - - private void setMaxAge(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge()); - } - -} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java deleted file mode 100644 index 8a6b405fe9c..00000000000 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.netty4; - -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.http.CorsHandler; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; - -public class Netty4CorsTests extends ESTestCase { - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up an HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create an HTTP transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create an HTTP transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = Netty4CorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - embeddedChannel.pipeline().addLast(new Netty4CorsHandler(CorsHandler.fromSettings(settings))); - Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest); - embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content"))); - return embeddedChannel.readOutbound(); - } -} diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index 6d7e8f3ed8a..e161ee690f1 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -27,12 +27,10 @@ import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.http.CorsHandler; import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedResponse; import org.elasticsearch.http.HttpReadTimeoutException; -import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannelHandler; @@ -60,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler { private int inFlightRequests = 0; public HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, - CorsHandler.Config corsConfig, TaskScheduler taskScheduler, LongSupplier nanoClock) { + TaskScheduler taskScheduler, LongSupplier nanoClock) { this.nioHttpChannel = nioHttpChannel; this.transport = transport; this.taskScheduler = taskScheduler; @@ -79,9 +77,6 @@ public class HttpReadWriteHandler implements NioChannelHandler { handlers.add(new HttpContentCompressor(settings.getCompressionLevel())); } handlers.add(new NioHttpRequestCreator()); - if (settings.isCorsEnabled()) { - handlers.add(new NioCorsHandler(corsConfig)); - } handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents())); adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); @@ -150,7 +145,6 @@ public class HttpReadWriteHandler implements NioChannelHandler { } } - @SuppressWarnings("unchecked") private void handleRequest(Object msg) { final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg; boolean success = false; diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 8594dae39f2..f10d9538d9e 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -169,7 +169,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) { NioHttpChannel httpChannel = new NioHttpChannel(channel); HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this, - handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis); + handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInMillis); Consumer exceptionHandler = (e) -> onException(httpChannel, e); SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler, new InboundChannelBuffer(pageAllocator)); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java deleted file mode 100644 index 31ac5fb8640..00000000000 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.nio.cors; - -import io.netty.channel.ChannelDuplexHandler; -import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaders; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import org.elasticsearch.common.Strings; -import org.elasticsearch.http.CorsHandler; -import org.elasticsearch.http.nio.NioHttpRequest; -import org.elasticsearch.http.nio.NioHttpResponse; - -import java.util.Date; -import java.util.regex.Pattern; -import java.util.stream.Collectors; - -/** - * Handles Cross Origin Resource Sharing (CORS) requests. - *

- * This handler can be configured using a {@link CorsHandler.Config}, please - * refer to this class for details about the configuration options available. - * - * This code was borrowed from Netty 4 and refactored to work for Elasticsearch's Netty 3 setup. - */ -public class NioCorsHandler extends ChannelDuplexHandler { - - public static final String ANY_ORIGIN = "*"; - private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://"); - - private final CorsHandler.Config config; - private NioHttpRequest request; - - /** - * Creates a new instance with the specified {@link CorsHandler.Config}. - */ - public NioCorsHandler(final CorsHandler.Config config) { - if (config == null) { - throw new NullPointerException(); - } - this.config = config; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - assert msg instanceof NioHttpRequest : "Invalid message type: " + msg.getClass(); - if (config.isCorsSupportEnabled()) { - request = (NioHttpRequest) msg; - if (isPreflightRequest(request.nettyRequest())) { - try { - handlePreflight(ctx, request.nettyRequest()); - return; - } finally { - releaseRequest(); - } - } - if (!validateOrigin()) { - try { - forbidden(ctx, request.nettyRequest()); - return; - } finally { - releaseRequest(); - } - } - } - ctx.fireChannelRead(msg); - } - - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass(); - NioHttpResponse response = (NioHttpResponse) msg; - setCorsResponseHeaders(response.requestHeaders(), response, config); - ctx.write(response, promise); - } - - public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) { - if (!config.isCorsSupportEnabled()) { - return; - } - String originHeader = headers.get(HttpHeaderNames.ORIGIN); - if (!Strings.isNullOrEmpty(originHeader)) { - final String originHeaderVal; - if (config.isAnyOriginSupported()) { - originHeaderVal = ANY_ORIGIN; - } else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) { - originHeaderVal = originHeader; - } else { - originHeaderVal = null; - } - if (originHeaderVal != null) { - resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal); - } - } - if (config.isCredentialsAllowed()) { - resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - } - } - - private void releaseRequest() { - request.release(); - request = null; - } - - private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { - final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true); - if (setOrigin(response)) { - setAllowMethods(response); - setAllowHeaders(response); - setAllowCredentials(response); - setMaxAge(response); - setPreflightHeaders(response); - ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); - } else { - forbidden(ctx, request); - } - } - - private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { - ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN)) - .addListener(ChannelFutureListener.CLOSE); - } - - private static boolean isSameOrigin(final String origin, final String host) { - if (Strings.isNullOrEmpty(host) == false) { - // strip protocol from origin - final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst(""); - if (host.equals(originDomain)) { - return true; - } - } - return false; - } - - /** - * This is a non CORS specification feature which enables the setting of preflight - * response headers that might be required by intermediaries. - * - * @param response the HttpResponse to which the preflight response headers should be added. - */ - private void setPreflightHeaders(final HttpResponse response) { - response.headers().add("date", new Date()); - response.headers().add("content-length", "0"); - } - - private boolean setOrigin(final HttpResponse response) { - final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN); - if (!Strings.isNullOrEmpty(origin)) { - if (config.isAnyOriginSupported()) { - if (config.isCredentialsAllowed()) { - echoRequestOrigin(response); - setVaryHeader(response); - } else { - setAnyOrigin(response); - } - return true; - } - if (config.isOriginAllowed(origin)) { - setOrigin(response, origin); - setVaryHeader(response); - return true; - } - } - return false; - } - - private boolean validateOrigin() { - if (config.isAnyOriginSupported()) { - return true; - } - - final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN); - if (Strings.isNullOrEmpty(origin)) { - // Not a CORS request so we cannot validate it. It may be a non CORS request. - return true; - } - - // if the origin is the same as the host of the request, then allow - if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) { - return true; - } - - return config.isOriginAllowed(origin); - } - - private void echoRequestOrigin(final HttpResponse response) { - setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN)); - } - - private static void setVaryHeader(final HttpResponse response) { - response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN); - } - - private static void setAnyOrigin(final HttpResponse response) { - setOrigin(response, ANY_ORIGIN); - } - - private static void setOrigin(final HttpResponse response, final String origin) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin); - } - - private void setAllowCredentials(final HttpResponse response) { - if (config.isCredentialsAllowed() - && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); - } - } - - private static boolean isPreflightRequest(final HttpRequest request) { - final HttpHeaders headers = request.headers(); - return request.method().equals(HttpMethod.OPTIONS) && - headers.contains(HttpHeaderNames.ORIGIN) && - headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD); - } - - private void setAllowMethods(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream() - .map(m -> m.name().trim()) - .collect(Collectors.toList())); - } - - private void setAllowHeaders(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders()); - } - - private void setMaxAge(final HttpResponse response) { - response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge()); - } - -} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index e5974b53b95..dabef89e787 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -44,9 +44,6 @@ import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpPipelinedResponse; import org.elasticsearch.http.HttpReadTimeoutException; import org.elasticsearch.http.HttpRequest; -import org.elasticsearch.http.HttpResponse; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.SocketChannelContext; @@ -64,16 +61,8 @@ import java.util.Iterator; import java.util.List; import java.util.function.BiConsumer; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.atLeastOnce; @@ -104,8 +93,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { channel = mock(NioHttpChannel.class); taskScheduler = mock(TaskScheduler.class); - CorsHandler.Config corsConfig = CorsHandler.disabled(); - handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime); + handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, System::nanoTime); handler.channelActive(); } @@ -211,135 +199,17 @@ public class HttpReadWriteHandlerTests extends ESTestCase { } } - public void testCorsEnabledWithoutAllowOrigins() throws IOException { - // Set up an HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); - try { - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } finally { - response.release(); - } - } - - public void testCorsEnabledWithAllowOrigins() throws IOException { - final String originValue = "remote-host"; - // create an HTTP transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - try { - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } finally { - response.release(); - } - } - - public void testCorsAllowOriginWithSameHost() throws IOException { - String originValue = "remote-host"; - String host = "remote-host"; - // create an HTTP transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - FullHttpResponse response = executeCorsRequest(settings, originValue, host); - String allowedOrigins; - try { - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } finally { - response.release(); - } - originValue = "http://" + originValue; - response = executeCorsRequest(settings, originValue, host); - try { - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } finally { - response.release(); - } - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeCorsRequest(settings, originValue, host); - try { - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } finally { - response.release(); - } - originValue = originValue.replace("http", "https"); - response = executeCorsRequest(settings, originValue, host); - try { - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } finally { - response.release(); - } - } - - public void testThatStringLiteralWorksOnMatch() throws IOException { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - try { - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } finally { - response.release(); - } - } - - public void testThatAnyOriginWorks() throws IOException { - final String originValue = NioCorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host"); - try { - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } finally { - response.release(); - } - } - @SuppressWarnings("unchecked") public void testReadTimeout() throws IOException { TimeValue timeValue = TimeValue.timeValueMillis(500); Settings settings = Settings.builder().put(SETTING_HTTP_READ_TIMEOUT.getKey(), timeValue).build(); HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); - CorsHandler.Config corsConfig = CorsHandler.disabled(); + CorsHandler corsHandler = CorsHandler.disabled(); TaskScheduler taskScheduler = new TaskScheduler(); Iterator timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator(); - handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next); + handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, timeValues::next); handler.channelActive(); prepareHandlerForResponse(handler); @@ -382,31 +252,6 @@ public class HttpReadWriteHandlerTests extends ESTestCase { return httpResponse; } - private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException { - HttpHandlingSettings httpSettings = HttpHandlingSettings.fromSettings(settings); - CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings); - HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler, - System::nanoTime); - handler.channelActive(); - prepareHandlerForResponse(handler); - DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - HttpPipelinedRequest pipelinedRequest = new HttpPipelinedRequest(0, new NioHttpRequest(httpRequest)); - BytesArray content = new BytesArray("content"); - HttpResponse response = pipelinedRequest.createResponse(RestStatus.OK, content); - response.addHeader("Content-Length", Integer.toString(content.length())); - - SocketChannelContext context = mock(SocketChannelContext.class); - List flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {})); - handler.close(); - FlushOperation flushOperation = flushOperations.get(0); - ((ChannelPromise) flushOperation.getListener()).setSuccess(); - return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); - } - private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java index 204c9ad0365..af8095d6dec 100644 --- a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java @@ -67,6 +67,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_ public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport { private static final Logger logger = LogManager.getLogger(AbstractHttpServerTransport.class); + private static final ActionListener NO_OP = ActionListener.wrap(() -> {}); protected final Settings settings; public final HttpHandlingSettings handlingSettings; @@ -74,7 +75,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected final BigArrays bigArrays; protected final ThreadPool threadPool; protected final Dispatcher dispatcher; - protected final CorsHandler.Config corsConfig; + protected final CorsHandler corsHandler; private final NamedXContentRegistry xContentRegistry; protected final PortsRange port; @@ -98,7 +99,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo this.xContentRegistry = xContentRegistry; this.dispatcher = dispatcher; this.handlingSettings = HttpHandlingSettings.fromSettings(settings); - this.corsConfig = CorsHandler.fromSettings(settings); + this.corsHandler = CorsHandler.fromSettings(settings); // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); @@ -321,6 +322,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo } private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + if (exception == null) { + HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest); + if (earlyResponse != null) { + httpChannel.sendResponse(earlyResponse, earlyResponseListener(httpRequest, httpChannel)); + httpRequest.release(); + return; + } + } + Exception badRequestCause = exception; /* @@ -359,12 +369,14 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo ThreadContext threadContext = threadPool.getThreadContext(); try { innerChannel = - new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, trace); + new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, corsHandler, + trace); } catch (final IllegalArgumentException e) { badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); innerChannel = - new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, trace); + new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, corsHandler, + trace); } channel = innerChannel; } @@ -381,4 +393,12 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } } + + private static ActionListener earlyResponseListener(HttpRequest request, HttpChannel httpChannel) { + if (HttpUtils.shouldCloseConnection(request)) { + return ActionListener.wrap(() -> CloseableChannel.closeChannel(httpChannel)); + } else { + return NO_OP; + } + } } diff --git a/server/src/main/java/org/elasticsearch/http/CorsHandler.java b/server/src/main/java/org/elasticsearch/http/CorsHandler.java index 3dfde32322c..b3b0641ec49 100644 --- a/server/src/main/java/org/elasticsearch/http/CorsHandler.java +++ b/server/src/main/java/org/elasticsearch/http/CorsHandler.java @@ -35,16 +35,23 @@ package org.elasticsearch.http; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestUtils; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashSet; +import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.regex.Pattern; @@ -62,7 +69,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE; * files: io.netty.handler.codec.http.cors.CorsHandler, io.netty.handler.codec.http.cors.CorsConfig, and * io.netty.handler.codec.http.cors.CorsConfigBuilder. * - * It modifies the original netty code to operation on Elasticsearch http request/response abstractions. + * It modifies the original netty code to operate on Elasticsearch http request/response abstractions. * Additionally, it removes CORS features that are not used by Elasticsearch. */ public class CorsHandler { @@ -71,10 +78,172 @@ public class CorsHandler { public static final String ORIGIN = "origin"; public static final String DATE = "date"; public static final String VARY = "vary"; + public static final String HOST = "host"; public static final String ACCESS_CONTROL_REQUEST_METHOD = "access-control-request-method"; + public static final String ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"; + public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"; + public static final String ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"; public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"; + public static final String ACCESS_CONTROL_MAX_AGE = "access-control-max-age"; - private CorsHandler() { + private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://"); + private static final DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss O", Locale.ENGLISH); + private final Config config; + + public CorsHandler(Config config) { + this.config = config; + } + + public HttpResponse handleInbound(HttpRequest request) { + if (config.isCorsSupportEnabled()) { + if (isPreflightRequest(request)) { + return handlePreflight(request); + } + + if (validateOrigin(request) == false) { + return forbidden(request); + } + } + return null; + } + + public void setCorsResponseHeaders(final HttpRequest httpRequest, final HttpResponse httpResponse) { + if (!config.isCorsSupportEnabled()) { + return; + } + if (setOrigin(httpRequest, httpResponse)) { + setAllowCredentials(httpResponse); + } + } + + private HttpResponse handlePreflight(final HttpRequest request) { + final HttpResponse response = request.createResponse(RestStatus.OK, BytesArray.EMPTY); + if (setOrigin(request, response)) { + setAllowMethods(response); + setAllowHeaders(response); + setAllowCredentials(response); + setMaxAge(response); + setPreflightHeaders(response); + return response; + } else { + return forbidden(request); + } + } + + private static HttpResponse forbidden(final HttpRequest request) { + HttpResponse response = request.createResponse(RestStatus.FORBIDDEN, BytesArray.EMPTY); + response.addHeader("content-length", "0"); + return response; + } + + private static boolean isSameOrigin(final String origin, final String host) { + if (Strings.isNullOrEmpty(host) == false) { + // strip protocol from origin + final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst(""); + if (host.equals(originDomain)) { + return true; + } + } + return false; + } + + private void setPreflightHeaders(final HttpResponse response) { + response.addHeader(CorsHandler.DATE, dateTimeFormatter.format(ZonedDateTime.now(ZoneOffset.UTC))); + response.addHeader("content-length", "0"); + } + + private boolean setOrigin(final HttpRequest request, final HttpResponse response) { + String origin = getOrigin(request); + if (!Strings.isNullOrEmpty(origin)) { + if (config.isAnyOriginSupported()) { + if (config.isCredentialsAllowed()) { + setAllowOrigin(response, origin); + setVaryHeader(response); + } else { + setAllowOrigin(response, ANY_ORIGIN); + } + return true; + } else if (config.isOriginAllowed(origin) || isSameOrigin(origin, getHost(request))) { + setAllowOrigin(response, origin); + setVaryHeader(response); + return true; + } + } + return false; + } + + private boolean validateOrigin(final HttpRequest request) { + if (config.isAnyOriginSupported()) { + return true; + } + + final String origin = getOrigin(request); + if (Strings.isNullOrEmpty(origin)) { + // Not a CORS request so we cannot validate it. It may be a non CORS request. + return true; + } + + // if the origin is the same as the host of the request, then allow + if (isSameOrigin(origin, getHost(request))) { + return true; + } + + return config.isOriginAllowed(origin); + } + + private static String getOrigin(HttpRequest request) { + List headers = request.getHeaders().get(ORIGIN); + if (headers == null || headers.isEmpty()) { + return null; + } else { + return headers.get(0); + } + } + + private static String getHost(HttpRequest request) { + List headers = request.getHeaders().get(HOST); + if (headers == null || headers.isEmpty()) { + return null; + } else { + return headers.get(0); + } + } + + private static boolean isPreflightRequest(final HttpRequest request) { + final Map> headers = request.getHeaders(); + return request.method().equals(RestRequest.Method.OPTIONS) && + headers.containsKey(ORIGIN) && + headers.containsKey(ACCESS_CONTROL_REQUEST_METHOD); + } + + private static void setVaryHeader(final HttpResponse response) { + response.addHeader(VARY, ORIGIN); + } + + private static void setAllowOrigin(final HttpResponse response, final String origin) { + response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } + + private void setAllowMethods(final HttpResponse response) { + for (RestRequest.Method method : config.allowedRequestMethods()) { + response.addHeader(ACCESS_CONTROL_ALLOW_METHODS, method.name().trim()); + } + } + + private void setAllowHeaders(final HttpResponse response) { + for (String header : config.allowedRequestHeaders) { + response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS, header); + } + } + + private void setAllowCredentials(final HttpResponse response) { + if (config.isCredentialsAllowed()) { + response.addHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + } + } + + private void setMaxAge(final HttpResponse response) { + response.addHeader(ACCESS_CONTROL_MAX_AGE, Long.toString(config.maxAge)); } public static class Config { @@ -218,15 +387,17 @@ public class CorsHandler { } } - public static Config disabled() { + public static CorsHandler disabled() { Config.Builder builder = new Config.Builder(); builder.enabled = false; - return new Config(builder); + return new CorsHandler(new Config(builder)); } - public static Config fromSettings(Settings settings) { + public static Config buildConfig(Settings settings) { if (SETTING_CORS_ENABLED.get(settings) == false) { - return disabled(); + Config.Builder builder = new Config.Builder(); + builder.enabled = false; + return new Config(builder); } String origin = SETTING_CORS_ALLOW_ORIGIN.get(settings); final CorsHandler.Config.Builder builder; @@ -260,4 +431,8 @@ public class CorsHandler { .build(); return config; } + + public static CorsHandler fromSettings(Settings settings) { + return new CorsHandler(buildConfig(settings)); + } } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java index 369d78208c2..b2e7ad9e8ba 100644 --- a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -60,19 +60,21 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann private final HttpHandlingSettings settings; private final ThreadContext threadContext; private final HttpChannel httpChannel; + private final CorsHandler corsHandler; @Nullable private final HttpTracer tracerLog; DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays, - HttpHandlingSettings settings, ThreadContext threadContext, @Nullable HttpTracer tracerLog) { + HttpHandlingSettings settings, ThreadContext threadContext, CorsHandler corsHandler, + @Nullable HttpTracer tracerLog) { super(request, settings.getDetailedErrorsEnabled()); this.httpChannel = httpChannel; - // TODO: Fix this.httpRequest = httpRequest; this.bigArrays = bigArrays; this.settings = settings; this.threadContext = threadContext; + this.corsHandler = corsHandler; this.tracerLog = tracerLog; } @@ -87,7 +89,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann Releasables.closeWhileHandlingException(httpRequest::release); final ArrayList toClose = new ArrayList<>(3); - if (isCloseConnection()) { + if (HttpUtils.shouldCloseConnection(httpRequest)) { toClose.add(() -> CloseableChannel.closeChannel(httpChannel)); } @@ -112,8 +114,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann final HttpResponse httpResponse = httpRequest.createResponse(restResponse.status(), finalContent); - // TODO: Ideally we should move the setting of Cors headers into :server - // NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); + corsHandler.setCorsResponseHeaders(httpRequest, httpResponse); opaque = request.header(X_OPAQUE_ID); if (opaque != null) { @@ -180,16 +181,4 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann } } } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - try { - final boolean http10 = request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; - return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) - || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); - } catch (Exception e) { - // In case we fail to parse the http protocol version out of the request we always close the connection - return true; - } - } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java index 80b854fbe7d..e07032db861 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -58,6 +59,22 @@ public interface HttpRequest { */ Map> getHeaders(); + default String header(String name) { + List values = getHeaders().get(name); + if (values != null && values.isEmpty() == false) { + return values.get(0); + } + return null; + } + + default List allHeaders(String name) { + List values = getHeaders().get(name); + if (values != null) { + return Collections.unmodifiableList(values); + } + return null; + } + List strictCookies(); HttpVersion protocolVersion(); diff --git a/server/src/main/java/org/elasticsearch/http/HttpUtils.java b/server/src/main/java/org/elasticsearch/http/HttpUtils.java new file mode 100644 index 00000000000..dbb41790ca6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpUtils.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +public class HttpUtils { + + static final String CLOSE = "close"; + static final String CONNECTION = "connection"; + static final String KEEP_ALIVE = "keep-alive"; + + // Determine if the request connection should be closed on completion. + public static boolean shouldCloseConnection(HttpRequest httpRequest) { + try { + final boolean http10 = httpRequest.protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + return CLOSE.equalsIgnoreCase(httpRequest.header(CONNECTION)) + || (http10 && !KEEP_ALIVE.equalsIgnoreCase(httpRequest.header(CONNECTION))); + } catch (Exception e) { + // In case we fail to parse the http protocol version out of the request we always close the connection + return true; + } + } +} diff --git a/server/src/main/java/org/elasticsearch/rest/RestUtils.java b/server/src/main/java/org/elasticsearch/rest/RestUtils.java index 827174743f7..f1907a2bf18 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestUtils.java +++ b/server/src/main/java/org/elasticsearch/rest/RestUtils.java @@ -233,7 +233,7 @@ public class RestUtils { return null; } int len = corsSetting.length(); - boolean isRegex = len > 2 && corsSetting.startsWith("/") && corsSetting.endsWith("/"); + boolean isRegex = len > 2 && corsSetting.startsWith("/") && corsSetting.endsWith("/"); if (isRegex) { return Pattern.compile(corsSetting.substring(1, corsSetting.length()-1)); diff --git a/server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java b/server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java index 92ea69e0a3d..13c006bfd63 100644 --- a/server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/http/CorsHandlerTests.java @@ -20,15 +20,19 @@ package org.elasticsearch.http; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Locale; +import java.util.Map; import java.util.Set; import java.util.regex.PatternSyntaxException; import java.util.stream.Collectors; @@ -40,8 +44,11 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ME import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.nullValue; public class CorsHandlerTests extends ESTestCase { @@ -51,7 +58,7 @@ public class CorsHandlerTests extends ESTestCase { .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "/[*/") .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) .build(); - SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.fromSettings(settings)); + SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.buildConfig(settings)); assertThat(e.getMessage(), containsString("Bad regex in [http.cors.allow-origin]: [/[*/]")); assertThat(e.getCause(), instanceOf(PatternSyntaxException.class)); } @@ -67,7 +74,7 @@ public class CorsHandlerTests extends ESTestCase { .put(SETTING_CORS_ALLOW_HEADERS.getKey(), collectionToDelimitedString(headers, ",", prefix, "")) .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) .build(); - final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings); + final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings); assertTrue(corsConfig.isAnyOriginSupported()); assertEquals(headers, corsConfig.allowedRequestHeaders()); assertEquals(methods.stream().map(s -> s.toUpperCase(Locale.ENGLISH)).collect(Collectors.toSet()), @@ -79,7 +86,7 @@ public class CorsHandlerTests extends ESTestCase { final Set headers = Strings.commaDelimitedListToSet(SETTING_CORS_ALLOW_HEADERS.getDefault(Settings.EMPTY)); final long maxAge = SETTING_CORS_MAX_AGE.getDefault(Settings.EMPTY); final Settings settings = Settings.builder().put(SETTING_CORS_ENABLED.getKey(), true).build(); - final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings); + final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings); assertFalse(corsConfig.isAnyOriginSupported()); assertEquals(Collections.emptySet(), corsConfig.origins().get()); assertEquals(headers, corsConfig.allowedRequestHeaders()); @@ -87,4 +94,236 @@ public class CorsHandlerTests extends ESTestCase { assertEquals(maxAge, corsConfig.maxAge()); assertFalse(corsConfig.isCredentialsAllowed()); } + + public void testHandleInboundNonCorsRequest() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + HttpResponse httpResponse = corsHandler.handleInbound(request); + // Since this is not a Cors request, there is not an early response + assertThat(httpResponse, nullValue()); + } + + public void testHandleInboundValidCorsRequest() { + final String validOriginLiteral = "valid-origin"; + final String originSetting; + if (randomBoolean()) { + originSetting = validOriginLiteral; + } else { + if (randomBoolean()) { + originSetting = "/valid-.+/"; + } else { + originSetting = "*"; + } + } + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(validOriginLiteral)); + HttpResponse httpResponse = corsHandler.handleInbound(request); + // Since is a Cors enabled request. However, it is not forbidden because the origin is allowed. + assertThat(httpResponse, nullValue()); + } + + public void testHandleInboundForbidden() { + final String validOriginLiteral = "valid-origin"; + final String originSetting; + if (randomBoolean()) { + originSetting = validOriginLiteral; + } else { + originSetting = "/valid-.+/"; + } + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("invalid-origin")); + TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request); + // Forbidden + assertThat(httpResponse.status(), equalTo(RestStatus.FORBIDDEN)); + } + + public void testHandleInboundAllowsSameOrigin() { + final String validOriginLiteral = "valid-origin"; + final String originSetting; + if (randomBoolean()) { + originSetting = validOriginLiteral; + } else { + originSetting = "/valid-.+/"; + } + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("https://same-host")); + request.getHeaders().put(CorsHandler.HOST, Collections.singletonList("same-host")); + TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request); + // Since is a Cors enabled request. However, it is not forbidden because the origin is the same as the host. + assertThat(httpResponse, nullValue()); + } + + public void testHandleInboundPreflightWithWildcardNoCredentials() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE") + .put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length") + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST")); + TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request); + + assertThat(httpResponse.status(), equalTo(RestStatus.OK)); + Map> headers = httpResponse.headers(); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS), + containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS), + containsInAnyOrder("Content-Type", "Content-Length")); + assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000")); + assertNotNull(headers.get(CorsHandler.DATE)); + } + + public void testHandleInboundPreflightWithWildcardAllowCredentials() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST")); + TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request); + + assertThat(httpResponse.status(), equalTo(RestStatus.OK)); + Map> headers = httpResponse.headers(); + // Since credentials are allowed, we echo the origin + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin")); + assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN)); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS), + containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS), + containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000")); + assertNotNull(headers.get(CorsHandler.DATE)); + } + + public void testHandleInboundPreflightWithValidOriginAllowCredentials() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin") + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST")); + TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request); + + assertThat(httpResponse.status(), equalTo(RestStatus.OK)); + Map> headers = httpResponse.headers(); + // Since credentials are allowed, we echo the origin + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin")); + assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN)); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS), + containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS), + containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000")); + assertNotNull(headers.get(CorsHandler.DATE)); + } + + public void testSetResponseNonCorsRequest() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE") + .put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length") + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY); + corsHandler.setCorsResponseHeaders(request, response); + + Map> headers = response.headers(); + assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN)); + } + + public void testSetResponseHeadersWithWildcardOrigin() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY); + corsHandler.setCorsResponseHeaders(request, response); + + Map> headers = response.headers(); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*")); + assertNull(headers.get(CorsHandler.VARY)); + } + + public void testSetResponseHeadersWithCredentialsWithWildcard() { + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY); + corsHandler.setCorsResponseHeaders(request, response); + + Map> headers = response.headers(); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin")); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true")); + assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN)); + } + + public void testSetResponseHeadersWithNonWildcardOrigin() { + boolean allowCredentials = randomBoolean(); + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), allowCredentials) + .build(); + CorsHandler corsHandler = CorsHandler.fromSettings(settings); + + TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin")); + TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY); + corsHandler.setCorsResponseHeaders(request, response); + + Map> headers = response.headers(); + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin")); + assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN)); + if (allowCredentials) { + assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true")); + } else { + assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + } + } } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java index 3c0d0652f01..d633c2211b6 100644 --- a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -50,19 +50,17 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; import java.nio.channels.ClosedChannelException; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.function.Supplier; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -90,109 +88,72 @@ public class DefaultRestChannelTests extends ESTestCase { } public void testResponse() { - final TestResponse response = executeRequest(Settings.EMPTY, "request-host"); + final TestHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); assertThat(response.content(), equalTo(new TestRestResponse().content())); } - // TODO: Enable these Cors tests when the Cors logic lives in :server + public void testCorsEnabledWithoutAllowOrigins() { + // Set up an HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + TestHttpResponse response = executeRequest(settings, "request-host"); + assertThat(response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } -// public void testCorsEnabledWithoutAllowOrigins() { -// // Set up an HTTP transport with only the CORS enabled setting -// Settings settings = Settings.builder() -// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) -// .build(); -// HttpResponse response = executeRequest(settings, "remote-host", "request-host"); -// // inspect response and validate -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); -// } -// -// public void testCorsEnabledWithAllowOrigins() { -// final String originValue = "remote-host"; -// // create an HTTP transport with CORS enabled and allow origin configured -// Settings settings = Settings.builder() -// .put(SETTING_CORS_ENABLED.getKey(), true) -// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) -// .build(); -// HttpResponse response = executeRequest(settings, originValue, "request-host"); -// // inspect response and validate -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// } -// -// public void testCorsAllowOriginWithSameHost() { -// String originValue = "remote-host"; -// String host = "remote-host"; -// // create an HTTP transport with CORS enabled -// Settings settings = Settings.builder() -// .put(SETTING_CORS_ENABLED.getKey(), true) -// .build(); -// HttpResponse response = executeRequest(settings, originValue, host); -// // inspect response and validate -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// -// originValue = "http://" + originValue; -// response = executeRequest(settings, originValue, host); -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// -// originValue = originValue + ":5555"; -// host = host + ":5555"; -// response = executeRequest(settings, originValue, host); -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// -// originValue = originValue.replace("http", "https"); -// response = executeRequest(settings, originValue, host); -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// } -// -// public void testThatStringLiteralWorksOnMatch() { -// final String originValue = "remote-host"; -// Settings settings = Settings.builder() -// .put(SETTING_CORS_ENABLED.getKey(), true) -// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) -// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") -// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) -// .build(); -// HttpResponse response = executeRequest(settings, originValue, "request-host"); -// // inspect response and validate -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); -// } -// -// public void testThatAnyOriginWorks() { -// final String originValue = NioCorsHandler.ANY_ORIGIN; -// Settings settings = Settings.builder() -// .put(SETTING_CORS_ENABLED.getKey(), true) -// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) -// .build(); -// HttpResponse response = executeRequest(settings, originValue, "request-host"); -// // inspect response and validate -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); -// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); -// assertThat(allowedOrigins, is(originValue)); -// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); -// } + public void testCorsEnabledWithAllowOrigins() { + final String originValue = "remote-host"; + final String pattern; + if (randomBoolean()) { + pattern = originValue; + } else { + pattern = "/remote-hos.+/"; + } + // create an HTTP transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), pattern) + .build(); + TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1"); + assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0)); + assertThat(response.headers().get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN)); + } + + public void testCorsEnabledWithAllowOriginsAndAllowCredentials() { + final String originValue = "remote-host"; + // create an HTTP transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), CorsHandler.ANY_ORIGIN) + .put(HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1"); + assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0)); + assertEquals(CorsHandler.ORIGIN, response.headers().get(CorsHandler.VARY).get(0)); + assertEquals("true", response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS).get(0)); + } + + public void testThatAnyOriginWorks() { + final String originValue = CorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1"); + assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0)); + assertNull(response.headers().get(CorsHandler.VARY)); + } public void testHeadersSet() { Settings settings = Settings.builder().build(); - final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc")); final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); // send a response DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, - threadPool.getThreadContext(), null); + threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null); TestRestResponse resp = new TestRestResponse(); final String customHeader = "custom-header"; final String customHeaderValue = "xyz"; @@ -200,10 +161,10 @@ public class DefaultRestChannelTests extends ESTestCase { channel.sendResponse(resp); // inspect what was written - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class); verify(httpChannel).sendResponse(responseCaptor.capture(), any()); - TestResponse httpResponse = responseCaptor.getValue(); - Map> headers = httpResponse.headers; + TestHttpResponse httpResponse = responseCaptor.getValue(); + Map> headers = httpResponse.headers(); assertNull(headers.get("non-existent-header")); assertEquals(customHeaderValue, headers.get(customHeader).get(0)); assertEquals("abc", headers.get(Task.X_OPAQUE_ID).get(0)); @@ -213,21 +174,21 @@ public class DefaultRestChannelTests extends ESTestCase { public void testCookiesSet() { Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build(); - final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc")); final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); // send a response DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, - threadPool.getThreadContext(), null); + threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null); channel.sendResponse(new TestRestResponse()); // inspect what was written - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class); verify(httpChannel).sendResponse(responseCaptor.capture(), any()); - TestResponse nioResponse = responseCaptor.getValue(); - Map> headers = nioResponse.headers; + TestHttpResponse nioResponse = responseCaptor.getValue(); + Map> headers = nioResponse.headers(); assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie")); assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2")); } @@ -235,12 +196,12 @@ public class DefaultRestChannelTests extends ESTestCase { @SuppressWarnings("unchecked") public void testReleaseInListener() throws IOException { final Settings settings = Settings.builder().build(); - final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, - threadPool.getThreadContext(), null); + threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null); final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, JsonXContent.contentBuilder().startObject().endObject()); assertThat(response.content(), not(instanceOf(Releasable.class))); @@ -276,16 +237,16 @@ public class DefaultRestChannelTests extends ESTestCase { final boolean brokenRequest = randomBoolean(); final boolean close = brokenRequest || randomBoolean(); if (brokenRequest) { - httpRequest = new TestRequest(() -> { + httpRequest = new TestHttpRequest(() -> { throw new IllegalArgumentException("Can't parse HTTP version"); }, RestRequest.Method.GET, "/"); } else if (randomBoolean()) { - httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); if (close) { httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE)); } } else { - httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/"); + httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/"); if (!close) { httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE)); } @@ -295,7 +256,7 @@ public class DefaultRestChannelTests extends ESTestCase { HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, - threadPool.getThreadContext(), null); + threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null); channel.sendResponse(new TestRestResponse()); Class> listenerClass = (Class>) (Class) ActionListener.class; ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); @@ -317,7 +278,7 @@ public class DefaultRestChannelTests extends ESTestCase { final boolean close = randomBoolean(); final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1; final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE; - final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") { + final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") { @Override public RestRequest.Method method() { throw new IllegalArgumentException("test"); @@ -326,7 +287,8 @@ public class DefaultRestChannelTests extends ESTestCase { request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue)); DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays, - HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null); + HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY), + null); // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); @@ -354,7 +316,7 @@ public class DefaultRestChannelTests extends ESTestCase { final boolean close = randomBoolean(); final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1; final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE; - final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") { + final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") { @Override public HttpResponse createResponse(RestStatus status, BytesReference content) { throw new IllegalArgumentException("test"); @@ -363,7 +325,8 @@ public class DefaultRestChannelTests extends ESTestCase { request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue)); DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays, - HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null); + HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY), + null); // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); @@ -379,142 +342,29 @@ public class DefaultRestChannelTests extends ESTestCase { } } - private TestResponse executeRequest(final Settings settings, final String host) { + private TestHttpResponse executeRequest(final Settings settings, final String host) { return executeRequest(settings, null, host); } - private TestResponse executeRequest(final Settings settings, final String originValue, final String host) { - HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); - // TODO: These exist for the Cors tests -// if (originValue != null) { -// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); -// } -// httpRequest.headers().add(HttpHeaderNames.HOST, host); + private TestHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { + HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + if (originValue != null) { + httpRequest.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(originValue)); + } + httpRequest.getHeaders().put(CorsHandler.HOST, Collections.singletonList(host)); final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings, - threadPool.getThreadContext(), null); + threadPool.getThreadContext(), new CorsHandler(CorsHandler.buildConfig(settings)), null); channel.sendResponse(new TestRestResponse()); // get the response - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class); verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any()); return responseCaptor.getValue(); } - private static class TestRequest implements HttpRequest { - - private final Supplier version; - private final RestRequest.Method method; - private final String uri; - private HashMap> headers = new HashMap<>(); - - private TestRequest(Supplier versionSupplier, RestRequest.Method method, String uri) { - this.version = versionSupplier; - this.method = method; - this.uri = uri; - } - - private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { - this(() -> version, method, uri); - } - - @Override - public RestRequest.Method method() { - return method; - } - - @Override - public String uri() { - return uri; - } - - @Override - public BytesReference content() { - return BytesArray.EMPTY; - } - - @Override - public Map> getHeaders() { - return headers; - } - - @Override - public List strictCookies() { - return Arrays.asList("cookie", "cookie2"); - } - - @Override - public HttpVersion protocolVersion() { - return version.get(); - } - - @Override - public HttpRequest removeHeader(String header) { - throw new UnsupportedOperationException("Do not support removing header on test request."); - } - - @Override - public HttpResponse createResponse(RestStatus status, BytesReference content) { - return new TestResponse(status, content); - } - - @Override - public void release() { - } - - @Override - public HttpRequest releaseAndCopy() { - return this; - } - - @Override - public Exception getInboundException() { - return null; - } - } - - private static class TestResponse implements HttpResponse { - - private final RestStatus status; - private final BytesReference content; - private final Map> headers = new HashMap<>(); - - TestResponse(RestStatus status, BytesReference content) { - this.status = status; - this.content = content; - } - - public String contentType() { - return "text"; - } - - public BytesReference content() { - return content; - } - - public RestStatus status() { - return status; - } - - @Override - public void addHeader(String name, String value) { - if (headers.containsKey(name) == false) { - ArrayList values = new ArrayList<>(); - values.add(value); - headers.put(name, values); - } else { - headers.get(name).add(value); - } - } - - @Override - public boolean containsHeader(String name) { - return headers.containsKey(name); - } - } - private static class TestRestResponse extends RestResponse { private final RestStatus status; diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java new file mode 100644 index 00000000000..02921c1af10 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/http/TestHttpRequest.java @@ -0,0 +1,103 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +class TestHttpRequest implements HttpRequest { + + private final Supplier version; + private final RestRequest.Method method; + private final String uri; + private final HashMap> headers = new HashMap<>(); + + TestHttpRequest(Supplier versionSupplier, RestRequest.Method method, String uri) { + this.version = versionSupplier; + this.method = method; + this.uri = uri; + } + + TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri) { + this(() -> version, method, uri); + } + + @Override + public RestRequest.Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return BytesArray.EMPTY; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Arrays.asList("cookie", "cookie2"); + } + + @Override + public HttpVersion protocolVersion() { + return version.get(); + } + + @Override + public HttpRequest removeHeader(String header) { + throw new UnsupportedOperationException("Do not support removing header on test request."); + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + return new TestHttpResponse(status, content); + } + + @Override + public void release() { + } + + @Override + public HttpRequest releaseAndCopy() { + return this; + } + + @Override + public Exception getInboundException() { + return null; + } +} diff --git a/server/src/test/java/org/elasticsearch/http/TestHttpResponse.java b/server/src/test/java/org/elasticsearch/http/TestHttpResponse.java new file mode 100644 index 00000000000..e56ab0ef1df --- /dev/null +++ b/server/src/test/java/org/elasticsearch/http/TestHttpResponse.java @@ -0,0 +1,68 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.rest.RestStatus; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +class TestHttpResponse implements HttpResponse { + + private final RestStatus status; + private final BytesReference content; + private final Map> headers = new HashMap<>(); + + TestHttpResponse(RestStatus status, BytesReference content) { + this.status = status; + this.content = content; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return status; + } + + public Map> headers() { + return headers; + } + + @Override + public void addHeader(String name, String value) { + if (headers.containsKey(name) == false) { + ArrayList values = new ArrayList<>(); + values.add(value); + headers.put(name, values); + } else { + headers.get(name).add(value); + } + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index d8f8b803632..ff2c91da208 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -94,7 +94,7 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException { NioHttpChannel httpChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, - handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); + handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); final NioChannelHandler handler; if (ipFilter != null) { handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);