Move CorsHandler to server ()

Currently we duplicate our specialized cors logic in all transport
plugins. This is unnecessary as it could be implemented in a single
place. This commit moves the logic to server. Additionally it fixes a
but where we are incorrectly closing http channels on early Cors
responses.
This commit is contained in:
Tim Brooks 2020-09-08 08:36:18 -06:00
parent 9c0444145e
commit 43a4882951
No known key found for this signature in database
GPG Key ID: C2AA3BB91A889E77
18 changed files with 772 additions and 1093 deletions
modules/transport-netty4/src
main/java/org/elasticsearch/http/netty4
test/java/org/elasticsearch/http/netty4
plugins/transport-nio/src
main/java/org/elasticsearch/http/nio
test/java/org/elasticsearch/http/nio
server/src
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio

View File

@ -59,10 +59,9 @@ import org.elasticsearch.http.HttpChannel;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.HttpServerChannel;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.SharedGroupFactory;
import org.elasticsearch.transport.NettyAllocator;
import org.elasticsearch.transport.SharedGroupFactory;
import org.elasticsearch.transport.netty4.Netty4Utils;
import java.net.InetSocketAddress;
@ -315,9 +314,6 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport {
ch.pipeline().addLast("encoder_compress", new HttpContentCompressor(handlingSettings.getCompressionLevel()));
}
ch.pipeline().addLast("request_creator", requestCreator);
if (handlingSettings.isCorsEnabled()) {
ch.pipeline().addLast("cors", new Netty4CorsHandler(transport.corsConfig));
}
ch.pipeline().addLast("pipelining", new Netty4HttpPipeliningHandler(logger, transport.pipeliningMaxEvents));
ch.pipeline().addLast("handler", requestHandler);
transport.serverAcceptedChannel(nettyHttpChannel);

View File

@ -1,253 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4.cors;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings;
import org.elasticsearch.http.CorsHandler;
import org.elasticsearch.http.netty4.Netty4HttpRequest;
import org.elasticsearch.http.netty4.Netty4HttpResponse;
import java.util.Date;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
* <p>
* This handler can be configured using a {@link CorsHandler.Config}, please
* refer to this class for details about the configuration options available.
*
*/
public class Netty4CorsHandler extends ChannelDuplexHandler {
public static final String ANY_ORIGIN = "*";
private static Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
private final CorsHandler.Config config;
private Netty4HttpRequest request;
/**
* Creates a new instance with the specified {@link CorsHandler.Config}.
*/
public Netty4CorsHandler(final CorsHandler.Config config) {
if (config == null) {
throw new NullPointerException();
}
this.config = config;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
assert msg instanceof Netty4HttpRequest : "Invalid message type: " + msg.getClass();
if (config.isCorsSupportEnabled()) {
request = (Netty4HttpRequest) msg;
if (isPreflightRequest(request.nettyRequest())) {
try {
handlePreflight(ctx, request.nettyRequest());
return;
} finally {
releaseRequest();
}
}
if (!validateOrigin()) {
try {
forbidden(ctx, request.nettyRequest());
return;
} finally {
releaseRequest();
}
}
}
ctx.fireChannelRead(msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass();
Netty4HttpResponse response = (Netty4HttpResponse) msg;
setCorsResponseHeaders(response.requestHeaders(), response, config);
ctx.write(response, promise);
}
public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) {
if (!config.isCorsSupportEnabled()) {
return;
}
String originHeader = headers.get(HttpHeaderNames.ORIGIN);
if (!Strings.isNullOrEmpty(originHeader)) {
final String originHeaderVal;
if (config.isAnyOriginSupported()) {
originHeaderVal = ANY_ORIGIN;
} else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) {
originHeaderVal = originHeader;
} else {
originHeaderVal = null;
}
if (originHeaderVal != null) {
resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal);
}
}
if (config.isCredentialsAllowed()) {
resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true);
if (setOrigin(response)) {
setAllowMethods(response);
setAllowHeaders(response);
setAllowCredentials(response);
setMaxAge(response);
setPreflightHeaders(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
forbidden(ctx, request);
}
}
private void releaseRequest() {
request.release();
request = null;
}
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE);
}
private static boolean isSameOrigin(final String origin, final String host) {
if (Strings.isNullOrEmpty(host) == false) {
// strip protocol from origin
final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
if (host.equals(originDomain)) {
return true;
}
}
return false;
}
/**
* This is a non CORS specification feature which enables the setting of preflight
* response headers that might be required by intermediaries.
*
* @param response the HttpResponse to which the preflight response headers should be added.
*/
private void setPreflightHeaders(final HttpResponse response) {
response.headers().add("date", new Date());
response.headers().add("content-length", "0");
}
private boolean setOrigin(final HttpResponse response) {
final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
if (!Strings.isNullOrEmpty(origin)) {
if (config.isAnyOriginSupported()) {
if (config.isCredentialsAllowed()) {
echoRequestOrigin(response);
setVaryHeader(response);
} else {
setAnyOrigin(response);
}
return true;
}
if (config.isOriginAllowed(origin)) {
setOrigin(response, origin);
setVaryHeader(response);
return true;
}
}
return false;
}
private boolean validateOrigin() {
if (config.isAnyOriginSupported()) {
return true;
}
final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
if (Strings.isNullOrEmpty(origin)) {
// Not a CORS request so we cannot validate it. It may be a non CORS request.
return true;
}
// if the origin is the same as the host of the request, then allow
if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) {
return true;
}
return config.isOriginAllowed(origin);
}
private void echoRequestOrigin(final HttpResponse response) {
setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN));
}
private static void setVaryHeader(final HttpResponse response) {
response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
}
private static void setAnyOrigin(final HttpResponse response) {
setOrigin(response, ANY_ORIGIN);
}
private static void setOrigin(final HttpResponse response, final String origin) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}
private void setAllowCredentials(final HttpResponse response) {
if (config.isCredentialsAllowed()
&& !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private static boolean isPreflightRequest(final HttpRequest request) {
final HttpHeaders headers = request.headers();
return request.method().equals(HttpMethod.OPTIONS) &&
headers.contains(HttpHeaderNames.ORIGIN) &&
headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
}
private void setAllowMethods(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream()
.map(m -> m.name().trim())
.collect(Collectors.toList()));
}
private void setAllowHeaders(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
}
private void setMaxAge(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
}
}

View File

@ -1,149 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.netty4;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpVersion;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.http.CorsHandler;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.netty4.cors.Netty4CorsHandler;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
public class Netty4CorsTests extends ESTestCase {
public void testCorsEnabledWithoutAllowOrigins() {
// Set up an HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
// create an HTTP transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testCorsAllowOriginWithSameHost() {
String originValue = "remote-host";
String host = "remote-host";
// create an HTTP transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, host);
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = "http://" + originValue;
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue + ":5555";
host = host + ":5555";
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
originValue = originValue.replace("http", "https");
response = executeRequest(settings, originValue, host);
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
}
public void testThatStringLiteralWorksOnMatch() {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
}
public void testThatAnyOriginWorks() {
final String originValue = Netty4CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
HttpResponse response = executeRequest(settings, originValue, "request-host");
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
}
private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
// construct request and send it over the transport layer
final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
EmbeddedChannel embeddedChannel = new EmbeddedChannel();
embeddedChannel.pipeline().addLast(new Netty4CorsHandler(CorsHandler.fromSettings(settings)));
Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest);
embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content")));
return embeddedChannel.readOutbound();
}
}

View File

@ -27,12 +27,10 @@ import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.http.CorsHandler;
import org.elasticsearch.http.HttpHandlingSettings;
import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.HttpPipelinedResponse;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.NioChannelHandler;
@ -60,7 +58,7 @@ public class HttpReadWriteHandler implements NioChannelHandler {
private int inFlightRequests = 0;
public HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings,
CorsHandler.Config corsConfig, TaskScheduler taskScheduler, LongSupplier nanoClock) {
TaskScheduler taskScheduler, LongSupplier nanoClock) {
this.nioHttpChannel = nioHttpChannel;
this.transport = transport;
this.taskScheduler = taskScheduler;
@ -79,9 +77,6 @@ public class HttpReadWriteHandler implements NioChannelHandler {
handlers.add(new HttpContentCompressor(settings.getCompressionLevel()));
}
handlers.add(new NioHttpRequestCreator());
if (settings.isCorsEnabled()) {
handlers.add(new NioCorsHandler(corsConfig));
}
handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents()));
adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0]));
@ -150,7 +145,6 @@ public class HttpReadWriteHandler implements NioChannelHandler {
}
}
@SuppressWarnings("unchecked")
private void handleRequest(Object msg) {
final HttpPipelinedRequest pipelinedRequest = (HttpPipelinedRequest) msg;
boolean success = false;

View File

@ -169,7 +169,7 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport {
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler handler = new HttpReadWriteHandler(httpChannel,NioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInMillis);
Consumer<Exception> exceptionHandler = (e) -> onException(httpChannel, e);
SocketChannelContext context = new BytesChannelContext(httpChannel, selector, socketConfig, exceptionHandler, handler,
new InboundChannelBuffer(pageAllocator));

View File

@ -1,254 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http.nio.cors;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import org.elasticsearch.common.Strings;
import org.elasticsearch.http.CorsHandler;
import org.elasticsearch.http.nio.NioHttpRequest;
import org.elasticsearch.http.nio.NioHttpResponse;
import java.util.Date;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* Handles <a href="http://www.w3.org/TR/cors/">Cross Origin Resource Sharing</a> (CORS) requests.
* <p>
* This handler can be configured using a {@link CorsHandler.Config}, please
* refer to this class for details about the configuration options available.
*
* This code was borrowed from Netty 4 and refactored to work for Elasticsearch's Netty 3 setup.
*/
public class NioCorsHandler extends ChannelDuplexHandler {
public static final String ANY_ORIGIN = "*";
private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
private final CorsHandler.Config config;
private NioHttpRequest request;
/**
* Creates a new instance with the specified {@link CorsHandler.Config}.
*/
public NioCorsHandler(final CorsHandler.Config config) {
if (config == null) {
throw new NullPointerException();
}
this.config = config;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
assert msg instanceof NioHttpRequest : "Invalid message type: " + msg.getClass();
if (config.isCorsSupportEnabled()) {
request = (NioHttpRequest) msg;
if (isPreflightRequest(request.nettyRequest())) {
try {
handlePreflight(ctx, request.nettyRequest());
return;
} finally {
releaseRequest();
}
}
if (!validateOrigin()) {
try {
forbidden(ctx, request.nettyRequest());
return;
} finally {
releaseRequest();
}
}
}
ctx.fireChannelRead(msg);
}
@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass();
NioHttpResponse response = (NioHttpResponse) msg;
setCorsResponseHeaders(response.requestHeaders(), response, config);
ctx.write(response, promise);
}
public static void setCorsResponseHeaders(HttpHeaders headers, HttpResponse resp, CorsHandler.Config config) {
if (!config.isCorsSupportEnabled()) {
return;
}
String originHeader = headers.get(HttpHeaderNames.ORIGIN);
if (!Strings.isNullOrEmpty(originHeader)) {
final String originHeaderVal;
if (config.isAnyOriginSupported()) {
originHeaderVal = ANY_ORIGIN;
} else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, headers.get(HttpHeaderNames.HOST))) {
originHeaderVal = originHeader;
} else {
originHeaderVal = null;
}
if (originHeaderVal != null) {
resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, originHeaderVal);
}
}
if (config.isCredentialsAllowed()) {
resp.headers().add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private void releaseRequest() {
request.release();
request = null;
}
private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) {
final HttpResponse response = new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.OK, true, true);
if (setOrigin(response)) {
setAllowMethods(response);
setAllowHeaders(response);
setAllowCredentials(response);
setMaxAge(response);
setPreflightHeaders(response);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
forbidden(ctx, request);
}
}
private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) {
ctx.writeAndFlush(new DefaultFullHttpResponse(request.protocolVersion(), HttpResponseStatus.FORBIDDEN))
.addListener(ChannelFutureListener.CLOSE);
}
private static boolean isSameOrigin(final String origin, final String host) {
if (Strings.isNullOrEmpty(host) == false) {
// strip protocol from origin
final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
if (host.equals(originDomain)) {
return true;
}
}
return false;
}
/**
* This is a non CORS specification feature which enables the setting of preflight
* response headers that might be required by intermediaries.
*
* @param response the HttpResponse to which the preflight response headers should be added.
*/
private void setPreflightHeaders(final HttpResponse response) {
response.headers().add("date", new Date());
response.headers().add("content-length", "0");
}
private boolean setOrigin(final HttpResponse response) {
final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
if (!Strings.isNullOrEmpty(origin)) {
if (config.isAnyOriginSupported()) {
if (config.isCredentialsAllowed()) {
echoRequestOrigin(response);
setVaryHeader(response);
} else {
setAnyOrigin(response);
}
return true;
}
if (config.isOriginAllowed(origin)) {
setOrigin(response, origin);
setVaryHeader(response);
return true;
}
}
return false;
}
private boolean validateOrigin() {
if (config.isAnyOriginSupported()) {
return true;
}
final String origin = request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN);
if (Strings.isNullOrEmpty(origin)) {
// Not a CORS request so we cannot validate it. It may be a non CORS request.
return true;
}
// if the origin is the same as the host of the request, then allow
if (isSameOrigin(origin, request.nettyRequest().headers().get(HttpHeaderNames.HOST))) {
return true;
}
return config.isOriginAllowed(origin);
}
private void echoRequestOrigin(final HttpResponse response) {
setOrigin(response, request.nettyRequest().headers().get(HttpHeaderNames.ORIGIN));
}
private static void setVaryHeader(final HttpResponse response) {
response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN);
}
private static void setAnyOrigin(final HttpResponse response) {
setOrigin(response, ANY_ORIGIN);
}
private static void setOrigin(final HttpResponse response, final String origin) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}
private void setAllowCredentials(final HttpResponse response) {
if (config.isCredentialsAllowed()
&& !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private static boolean isPreflightRequest(final HttpRequest request) {
final HttpHeaders headers = request.headers();
return request.method().equals(HttpMethod.OPTIONS) &&
headers.contains(HttpHeaderNames.ORIGIN) &&
headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
}
private void setAllowMethods(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods().stream()
.map(m -> m.name().trim())
.collect(Collectors.toList()));
}
private void setAllowHeaders(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders());
}
private void setMaxAge(final HttpResponse response) {
response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
}
}

View File

@ -44,9 +44,6 @@ import org.elasticsearch.http.HttpPipelinedRequest;
import org.elasticsearch.http.HttpPipelinedResponse;
import org.elasticsearch.http.HttpReadTimeoutException;
import org.elasticsearch.http.HttpRequest;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.http.HttpTransportSettings;
import org.elasticsearch.http.nio.cors.NioCorsHandler;
import org.elasticsearch.nio.FlushOperation;
import org.elasticsearch.nio.InboundChannelBuffer;
import org.elasticsearch.nio.SocketChannelContext;
@ -64,16 +61,8 @@ import java.util.Iterator;
import java.util.List;
import java.util.function.BiConsumer;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_CONTENT_LENGTH;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_READ_TIMEOUT;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.atLeastOnce;
@ -104,8 +93,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
channel = mock(NioHttpChannel.class);
taskScheduler = mock(TaskScheduler.class);
CorsHandler.Config corsConfig = CorsHandler.disabled();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, System::nanoTime);
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, System::nanoTime);
handler.channelActive();
}
@ -211,135 +199,17 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
}
}
public void testCorsEnabledWithoutAllowOrigins() throws IOException {
// Set up an HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, "remote-host", "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
} finally {
response.release();
}
}
public void testCorsEnabledWithAllowOrigins() throws IOException {
final String originValue = "remote-host";
// create an HTTP transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}
public void testCorsAllowOriginWithSameHost() throws IOException {
String originValue = "remote-host";
String host = "remote-host";
// create an HTTP transport with CORS enabled
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, host);
String allowedOrigins;
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = "http://" + originValue;
response = executeCorsRequest(settings, originValue, host);
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = originValue + ":5555";
host = host + ":5555";
response = executeCorsRequest(settings, originValue, host);
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
originValue = originValue.replace("http", "https");
response = executeCorsRequest(settings, originValue, host);
try {
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
} finally {
response.release();
}
}
public void testThatStringLiteralWorksOnMatch() throws IOException {
final String originValue = "remote-host";
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
} finally {
response.release();
}
}
public void testThatAnyOriginWorks() throws IOException {
final String originValue = NioCorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
FullHttpResponse response = executeCorsRequest(settings, originValue, "request-host");
try {
// inspect response and validate
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
assertThat(allowedOrigins, is(originValue));
assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
} finally {
response.release();
}
}
@SuppressWarnings("unchecked")
public void testReadTimeout() throws IOException {
TimeValue timeValue = TimeValue.timeValueMillis(500);
Settings settings = Settings.builder().put(SETTING_HTTP_READ_TIMEOUT.getKey(), timeValue).build();
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
CorsHandler.Config corsConfig = CorsHandler.disabled();
CorsHandler corsHandler = CorsHandler.disabled();
TaskScheduler taskScheduler = new TaskScheduler();
Iterator<Integer> timeValues = Arrays.asList(0, 2, 4, 6, 8).iterator();
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, corsConfig, taskScheduler, timeValues::next);
handler = new HttpReadWriteHandler(channel, transport, httpHandlingSettings, taskScheduler, timeValues::next);
handler.channelActive();
prepareHandlerForResponse(handler);
@ -382,31 +252,6 @@ public class HttpReadWriteHandlerTests extends ESTestCase {
return httpResponse;
}
private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException {
HttpHandlingSettings httpSettings = HttpHandlingSettings.fromSettings(settings);
CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
HttpReadWriteHandler handler = new HttpReadWriteHandler(channel, transport, httpSettings, corsConfig, taskScheduler,
System::nanoTime);
handler.channelActive();
prepareHandlerForResponse(handler);
DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
if (originValue != null) {
httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
}
httpRequest.headers().add(HttpHeaderNames.HOST, host);
HttpPipelinedRequest pipelinedRequest = new HttpPipelinedRequest(0, new NioHttpRequest(httpRequest));
BytesArray content = new BytesArray("content");
HttpResponse response = pipelinedRequest.createResponse(RestStatus.OK, content);
response.addHeader("Content-Length", Integer.toString(content.length()));
SocketChannelContext context = mock(SocketChannelContext.class);
List<FlushOperation> flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {}));
handler.close();
FlushOperation flushOperation = flushOperations.get(0);
((ChannelPromise) flushOperation.getListener()).setSuccess();
return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite()));
}
private void prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException {

View File

@ -67,6 +67,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_
public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport {
private static final Logger logger = LogManager.getLogger(AbstractHttpServerTransport.class);
private static final ActionListener<Void> NO_OP = ActionListener.wrap(() -> {});
protected final Settings settings;
public final HttpHandlingSettings handlingSettings;
@ -74,7 +75,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
protected final BigArrays bigArrays;
protected final ThreadPool threadPool;
protected final Dispatcher dispatcher;
protected final CorsHandler.Config corsConfig;
protected final CorsHandler corsHandler;
private final NamedXContentRegistry xContentRegistry;
protected final PortsRange port;
@ -98,7 +99,7 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
this.xContentRegistry = xContentRegistry;
this.dispatcher = dispatcher;
this.handlingSettings = HttpHandlingSettings.fromSettings(settings);
this.corsConfig = CorsHandler.fromSettings(settings);
this.corsHandler = CorsHandler.fromSettings(settings);
// we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here
List<String> httpBindHost = SETTING_HTTP_BIND_HOST.get(settings);
@ -321,6 +322,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
}
private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) {
if (exception == null) {
HttpResponse earlyResponse = corsHandler.handleInbound(httpRequest);
if (earlyResponse != null) {
httpChannel.sendResponse(earlyResponse, earlyResponseListener(httpRequest, httpChannel));
httpRequest.release();
return;
}
}
Exception badRequestCause = exception;
/*
@ -359,12 +369,14 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
ThreadContext threadContext = threadPool.getThreadContext();
try {
innerChannel =
new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, trace);
new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext, corsHandler,
trace);
} catch (final IllegalArgumentException e) {
badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e);
final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel);
innerChannel =
new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, trace);
new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext, corsHandler,
trace);
}
channel = innerChannel;
}
@ -381,4 +393,12 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo
return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel);
}
}
private static ActionListener<Void> earlyResponseListener(HttpRequest request, HttpChannel httpChannel) {
if (HttpUtils.shouldCloseConnection(request)) {
return ActionListener.wrap(() -> CloseableChannel.closeChannel(httpChannel));
} else {
return NO_OP;
}
}
}

View File

@ -35,16 +35,23 @@
package org.elasticsearch.http;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.rest.RestUtils;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
@ -62,7 +69,7 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE;
* files: io.netty.handler.codec.http.cors.CorsHandler, io.netty.handler.codec.http.cors.CorsConfig, and
* io.netty.handler.codec.http.cors.CorsConfigBuilder.
*
* It modifies the original netty code to operation on Elasticsearch http request/response abstractions.
* It modifies the original netty code to operate on Elasticsearch http request/response abstractions.
* Additionally, it removes CORS features that are not used by Elasticsearch.
*/
public class CorsHandler {
@ -71,10 +78,172 @@ public class CorsHandler {
public static final String ORIGIN = "origin";
public static final String DATE = "date";
public static final String VARY = "vary";
public static final String HOST = "host";
public static final String ACCESS_CONTROL_REQUEST_METHOD = "access-control-request-method";
public static final String ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers";
public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials";
public static final String ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods";
public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin";
public static final String ACCESS_CONTROL_MAX_AGE = "access-control-max-age";
private CorsHandler() {
private static final Pattern SCHEME_PATTERN = Pattern.compile("^https?://");
private static final DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss O", Locale.ENGLISH);
private final Config config;
public CorsHandler(Config config) {
this.config = config;
}
public HttpResponse handleInbound(HttpRequest request) {
if (config.isCorsSupportEnabled()) {
if (isPreflightRequest(request)) {
return handlePreflight(request);
}
if (validateOrigin(request) == false) {
return forbidden(request);
}
}
return null;
}
public void setCorsResponseHeaders(final HttpRequest httpRequest, final HttpResponse httpResponse) {
if (!config.isCorsSupportEnabled()) {
return;
}
if (setOrigin(httpRequest, httpResponse)) {
setAllowCredentials(httpResponse);
}
}
private HttpResponse handlePreflight(final HttpRequest request) {
final HttpResponse response = request.createResponse(RestStatus.OK, BytesArray.EMPTY);
if (setOrigin(request, response)) {
setAllowMethods(response);
setAllowHeaders(response);
setAllowCredentials(response);
setMaxAge(response);
setPreflightHeaders(response);
return response;
} else {
return forbidden(request);
}
}
private static HttpResponse forbidden(final HttpRequest request) {
HttpResponse response = request.createResponse(RestStatus.FORBIDDEN, BytesArray.EMPTY);
response.addHeader("content-length", "0");
return response;
}
private static boolean isSameOrigin(final String origin, final String host) {
if (Strings.isNullOrEmpty(host) == false) {
// strip protocol from origin
final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst("");
if (host.equals(originDomain)) {
return true;
}
}
return false;
}
private void setPreflightHeaders(final HttpResponse response) {
response.addHeader(CorsHandler.DATE, dateTimeFormatter.format(ZonedDateTime.now(ZoneOffset.UTC)));
response.addHeader("content-length", "0");
}
private boolean setOrigin(final HttpRequest request, final HttpResponse response) {
String origin = getOrigin(request);
if (!Strings.isNullOrEmpty(origin)) {
if (config.isAnyOriginSupported()) {
if (config.isCredentialsAllowed()) {
setAllowOrigin(response, origin);
setVaryHeader(response);
} else {
setAllowOrigin(response, ANY_ORIGIN);
}
return true;
} else if (config.isOriginAllowed(origin) || isSameOrigin(origin, getHost(request))) {
setAllowOrigin(response, origin);
setVaryHeader(response);
return true;
}
}
return false;
}
private boolean validateOrigin(final HttpRequest request) {
if (config.isAnyOriginSupported()) {
return true;
}
final String origin = getOrigin(request);
if (Strings.isNullOrEmpty(origin)) {
// Not a CORS request so we cannot validate it. It may be a non CORS request.
return true;
}
// if the origin is the same as the host of the request, then allow
if (isSameOrigin(origin, getHost(request))) {
return true;
}
return config.isOriginAllowed(origin);
}
private static String getOrigin(HttpRequest request) {
List<String> headers = request.getHeaders().get(ORIGIN);
if (headers == null || headers.isEmpty()) {
return null;
} else {
return headers.get(0);
}
}
private static String getHost(HttpRequest request) {
List<String> headers = request.getHeaders().get(HOST);
if (headers == null || headers.isEmpty()) {
return null;
} else {
return headers.get(0);
}
}
private static boolean isPreflightRequest(final HttpRequest request) {
final Map<String, List<String>> headers = request.getHeaders();
return request.method().equals(RestRequest.Method.OPTIONS) &&
headers.containsKey(ORIGIN) &&
headers.containsKey(ACCESS_CONTROL_REQUEST_METHOD);
}
private static void setVaryHeader(final HttpResponse response) {
response.addHeader(VARY, ORIGIN);
}
private static void setAllowOrigin(final HttpResponse response, final String origin) {
response.addHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
}
private void setAllowMethods(final HttpResponse response) {
for (RestRequest.Method method : config.allowedRequestMethods()) {
response.addHeader(ACCESS_CONTROL_ALLOW_METHODS, method.name().trim());
}
}
private void setAllowHeaders(final HttpResponse response) {
for (String header : config.allowedRequestHeaders) {
response.addHeader(ACCESS_CONTROL_ALLOW_HEADERS, header);
}
}
private void setAllowCredentials(final HttpResponse response) {
if (config.isCredentialsAllowed()) {
response.addHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
}
}
private void setMaxAge(final HttpResponse response) {
response.addHeader(ACCESS_CONTROL_MAX_AGE, Long.toString(config.maxAge));
}
public static class Config {
@ -218,15 +387,17 @@ public class CorsHandler {
}
}
public static Config disabled() {
public static CorsHandler disabled() {
Config.Builder builder = new Config.Builder();
builder.enabled = false;
return new Config(builder);
return new CorsHandler(new Config(builder));
}
public static Config fromSettings(Settings settings) {
public static Config buildConfig(Settings settings) {
if (SETTING_CORS_ENABLED.get(settings) == false) {
return disabled();
Config.Builder builder = new Config.Builder();
builder.enabled = false;
return new Config(builder);
}
String origin = SETTING_CORS_ALLOW_ORIGIN.get(settings);
final CorsHandler.Config.Builder builder;
@ -260,4 +431,8 @@ public class CorsHandler {
.build();
return config;
}
public static CorsHandler fromSettings(Settings settings) {
return new CorsHandler(buildConfig(settings));
}
}

View File

@ -60,19 +60,21 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
private final HttpHandlingSettings settings;
private final ThreadContext threadContext;
private final HttpChannel httpChannel;
private final CorsHandler corsHandler;
@Nullable
private final HttpTracer tracerLog;
DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays,
HttpHandlingSettings settings, ThreadContext threadContext, @Nullable HttpTracer tracerLog) {
HttpHandlingSettings settings, ThreadContext threadContext, CorsHandler corsHandler,
@Nullable HttpTracer tracerLog) {
super(request, settings.getDetailedErrorsEnabled());
this.httpChannel = httpChannel;
// TODO: Fix
this.httpRequest = httpRequest;
this.bigArrays = bigArrays;
this.settings = settings;
this.threadContext = threadContext;
this.corsHandler = corsHandler;
this.tracerLog = tracerLog;
}
@ -87,7 +89,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
Releasables.closeWhileHandlingException(httpRequest::release);
final ArrayList<Releasable> toClose = new ArrayList<>(3);
if (isCloseConnection()) {
if (HttpUtils.shouldCloseConnection(httpRequest)) {
toClose.add(() -> CloseableChannel.closeChannel(httpChannel));
}
@ -112,8 +114,7 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
final HttpResponse httpResponse = httpRequest.createResponse(restResponse.status(), finalContent);
// TODO: Ideally we should move the setting of Cors headers into :server
// NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig);
corsHandler.setCorsResponseHeaders(httpRequest, httpResponse);
opaque = request.header(X_OPAQUE_ID);
if (opaque != null) {
@ -180,16 +181,4 @@ public class DefaultRestChannel extends AbstractRestChannel implements RestChann
}
}
}
// Determine if the request connection should be closed on completion.
private boolean isCloseConnection() {
try {
final boolean http10 = request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
return CLOSE.equalsIgnoreCase(request.header(CONNECTION))
|| (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION)));
} catch (Exception e) {
// In case we fail to parse the http protocol version out of the request we always close the connection
return true;
}
}
}

View File

@ -24,6 +24,7 @@ import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -58,6 +59,22 @@ public interface HttpRequest {
*/
Map<String, List<String>> getHeaders();
default String header(String name) {
List<String> values = getHeaders().get(name);
if (values != null && values.isEmpty() == false) {
return values.get(0);
}
return null;
}
default List<String> allHeaders(String name) {
List<String> values = getHeaders().get(name);
if (values != null) {
return Collections.unmodifiableList(values);
}
return null;
}
List<String> strictCookies();
HttpVersion protocolVersion();

View File

@ -0,0 +1,39 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
public class HttpUtils {
static final String CLOSE = "close";
static final String CONNECTION = "connection";
static final String KEEP_ALIVE = "keep-alive";
// Determine if the request connection should be closed on completion.
public static boolean shouldCloseConnection(HttpRequest httpRequest) {
try {
final boolean http10 = httpRequest.protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0;
return CLOSE.equalsIgnoreCase(httpRequest.header(CONNECTION))
|| (http10 && !KEEP_ALIVE.equalsIgnoreCase(httpRequest.header(CONNECTION)));
} catch (Exception e) {
// In case we fail to parse the http protocol version out of the request we always close the connection
return true;
}
}
}

View File

@ -233,7 +233,7 @@ public class RestUtils {
return null;
}
int len = corsSetting.length();
boolean isRegex = len > 2 && corsSetting.startsWith("/") && corsSetting.endsWith("/");
boolean isRegex = len > 2 && corsSetting.startsWith("/") && corsSetting.endsWith("/");
if (isRegex) {
return Pattern.compile(corsSetting.substring(1, corsSetting.length()-1));

View File

@ -20,15 +20,19 @@
package org.elasticsearch.http;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsException;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.regex.PatternSyntaxException;
import java.util.stream.Collectors;
@ -40,8 +44,11 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ME
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED;
import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_MAX_AGE;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.nullValue;
public class CorsHandlerTests extends ESTestCase {
@ -51,7 +58,7 @@ public class CorsHandlerTests extends ESTestCase {
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "/[*/")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.fromSettings(settings));
SettingsException e = expectThrows(SettingsException.class, () -> CorsHandler.buildConfig(settings));
assertThat(e.getMessage(), containsString("Bad regex in [http.cors.allow-origin]: [/[*/]"));
assertThat(e.getCause(), instanceOf(PatternSyntaxException.class));
}
@ -67,7 +74,7 @@ public class CorsHandlerTests extends ESTestCase {
.put(SETTING_CORS_ALLOW_HEADERS.getKey(), collectionToDelimitedString(headers, ",", prefix, ""))
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings);
assertTrue(corsConfig.isAnyOriginSupported());
assertEquals(headers, corsConfig.allowedRequestHeaders());
assertEquals(methods.stream().map(s -> s.toUpperCase(Locale.ENGLISH)).collect(Collectors.toSet()),
@ -79,7 +86,7 @@ public class CorsHandlerTests extends ESTestCase {
final Set<String> headers = Strings.commaDelimitedListToSet(SETTING_CORS_ALLOW_HEADERS.getDefault(Settings.EMPTY));
final long maxAge = SETTING_CORS_MAX_AGE.getDefault(Settings.EMPTY);
final Settings settings = Settings.builder().put(SETTING_CORS_ENABLED.getKey(), true).build();
final CorsHandler.Config corsConfig = CorsHandler.fromSettings(settings);
final CorsHandler.Config corsConfig = CorsHandler.buildConfig(settings);
assertFalse(corsConfig.isAnyOriginSupported());
assertEquals(Collections.emptySet(), corsConfig.origins().get());
assertEquals(headers, corsConfig.allowedRequestHeaders());
@ -87,4 +94,236 @@ public class CorsHandlerTests extends ESTestCase {
assertEquals(maxAge, corsConfig.maxAge());
assertFalse(corsConfig.isCredentialsAllowed());
}
public void testHandleInboundNonCorsRequest() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
HttpResponse httpResponse = corsHandler.handleInbound(request);
// Since this is not a Cors request, there is not an early response
assertThat(httpResponse, nullValue());
}
public void testHandleInboundValidCorsRequest() {
final String validOriginLiteral = "valid-origin";
final String originSetting;
if (randomBoolean()) {
originSetting = validOriginLiteral;
} else {
if (randomBoolean()) {
originSetting = "/valid-.+/";
} else {
originSetting = "*";
}
}
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(validOriginLiteral));
HttpResponse httpResponse = corsHandler.handleInbound(request);
// Since is a Cors enabled request. However, it is not forbidden because the origin is allowed.
assertThat(httpResponse, nullValue());
}
public void testHandleInboundForbidden() {
final String validOriginLiteral = "valid-origin";
final String originSetting;
if (randomBoolean()) {
originSetting = validOriginLiteral;
} else {
originSetting = "/valid-.+/";
}
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("invalid-origin"));
TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
// Forbidden
assertThat(httpResponse.status(), equalTo(RestStatus.FORBIDDEN));
}
public void testHandleInboundAllowsSameOrigin() {
final String validOriginLiteral = "valid-origin";
final String originSetting;
if (randomBoolean()) {
originSetting = validOriginLiteral;
} else {
originSetting = "/valid-.+/";
}
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originSetting)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.POST, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("https://same-host"));
request.getHeaders().put(CorsHandler.HOST, Collections.singletonList("same-host"));
TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
// Since is a Cors enabled request. However, it is not forbidden because the origin is the same as the host.
assertThat(httpResponse, nullValue());
}
public void testHandleInboundPreflightWithWildcardNoCredentials() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE")
.put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length")
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
assertThat(httpResponse.status(), equalTo(RestStatus.OK));
Map<String, List<String>> headers = httpResponse.headers();
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
containsInAnyOrder("Content-Type", "Content-Length"));
assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
assertNotNull(headers.get(CorsHandler.DATE));
}
public void testHandleInboundPreflightWithWildcardAllowCredentials() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
assertThat(httpResponse.status(), equalTo(RestStatus.OK));
Map<String, List<String>> headers = httpResponse.headers();
// Since credentials are allowed, we echo the origin
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
assertNotNull(headers.get(CorsHandler.DATE));
}
public void testHandleInboundPreflightWithValidOriginAllowCredentials() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin")
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE,POST")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.OPTIONS, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
request.getHeaders().put(CorsHandler.ACCESS_CONTROL_REQUEST_METHOD, Collections.singletonList("POST"));
TestHttpResponse httpResponse = (TestHttpResponse) corsHandler.handleInbound(request);
assertThat(httpResponse.status(), equalTo(RestStatus.OK));
Map<String, List<String>> headers = httpResponse.headers();
// Since credentials are allowed, we echo the origin
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_METHODS),
containsInAnyOrder("HEAD", "OPTIONS", "GET", "DELETE", "POST"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_HEADERS),
containsInAnyOrder("X-Requested-With", "Content-Type", "Content-Length"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_MAX_AGE), containsInAnyOrder("1728000"));
assertNotNull(headers.get(CorsHandler.DATE));
}
public void testSetResponseNonCorsRequest() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
.put(SETTING_CORS_ALLOW_METHODS.getKey(), "OPTIONS,HEAD,GET,DELETE")
.put(SETTING_CORS_ALLOW_HEADERS.getKey(), "Content-Type,Content-Length")
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
corsHandler.setCorsResponseHeaders(request, response);
Map<String, List<String>> headers = response.headers();
assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN));
}
public void testSetResponseHeadersWithWildcardOrigin() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
corsHandler.setCorsResponseHeaders(request, response);
Map<String, List<String>> headers = response.headers();
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("*"));
assertNull(headers.get(CorsHandler.VARY));
}
public void testSetResponseHeadersWithCredentialsWithWildcard() {
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "*")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
corsHandler.setCorsResponseHeaders(request, response);
Map<String, List<String>> headers = response.headers();
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
}
public void testSetResponseHeadersWithNonWildcardOrigin() {
boolean allowCredentials = randomBoolean();
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(SETTING_CORS_ALLOW_ORIGIN.getKey(), "valid-origin")
.put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), allowCredentials)
.build();
CorsHandler corsHandler = CorsHandler.fromSettings(settings);
TestHttpRequest request = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
request.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList("valid-origin"));
TestHttpResponse response = new TestHttpResponse(RestStatus.OK, BytesArray.EMPTY);
corsHandler.setCorsResponseHeaders(request, response);
Map<String, List<String>> headers = response.headers();
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), containsInAnyOrder("valid-origin"));
assertThat(headers.get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
if (allowCredentials) {
assertThat(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS), containsInAnyOrder("true"));
} else {
assertNull(headers.get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS));
}
}
}

View File

@ -50,19 +50,17 @@ import org.mockito.ArgumentCaptor;
import java.io.IOException;
import java.nio.channels.ClosedChannelException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
@ -90,109 +88,72 @@ public class DefaultRestChannelTests extends ESTestCase {
}
public void testResponse() {
final TestResponse response = executeRequest(Settings.EMPTY, "request-host");
final TestHttpResponse response = executeRequest(Settings.EMPTY, "request-host");
assertThat(response.content(), equalTo(new TestRestResponse().content()));
}
// TODO: Enable these Cors tests when the Cors logic lives in :server
public void testCorsEnabledWithoutAllowOrigins() {
// Set up an HTTP transport with only the CORS enabled setting
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.build();
TestHttpResponse response = executeRequest(settings, "request-host");
assertThat(response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
}
// public void testCorsEnabledWithoutAllowOrigins() {
// // Set up an HTTP transport with only the CORS enabled setting
// Settings settings = Settings.builder()
// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, "remote-host", "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue());
// }
//
// public void testCorsEnabledWithAllowOrigins() {
// final String originValue = "remote-host";
// // create an HTTP transport with CORS enabled and allow origin configured
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testCorsAllowOriginWithSameHost() {
// String originValue = "remote-host";
// String host = "remote-host";
// // create an HTTP transport with CORS enabled
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, host);
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = "http://" + originValue;
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue + ":5555";
// host = host + ":5555";
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
//
// originValue = originValue.replace("http", "https");
// response = executeRequest(settings, originValue, host);
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// }
//
// public void testThatStringLiteralWorksOnMatch() {
// final String originValue = "remote-host";
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post")
// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true"));
// }
//
// public void testThatAnyOriginWorks() {
// final String originValue = NioCorsHandler.ANY_ORIGIN;
// Settings settings = Settings.builder()
// .put(SETTING_CORS_ENABLED.getKey(), true)
// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
// .build();
// HttpResponse response = executeRequest(settings, originValue, "request-host");
// // inspect response and validate
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue());
// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN);
// assertThat(allowedOrigins, is(originValue));
// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue());
// }
public void testCorsEnabledWithAllowOrigins() {
final String originValue = "remote-host";
final String pattern;
if (randomBoolean()) {
pattern = originValue;
} else {
pattern = "/remote-hos.+/";
}
// create an HTTP transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), pattern)
.build();
TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
assertThat(response.headers().get(CorsHandler.VARY), containsInAnyOrder(CorsHandler.ORIGIN));
}
public void testCorsEnabledWithAllowOriginsAndAllowCredentials() {
final String originValue = "remote-host";
// create an HTTP transport with CORS enabled and allow origin configured
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), CorsHandler.ANY_ORIGIN)
.put(HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true)
.build();
TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
assertEquals(CorsHandler.ORIGIN, response.headers().get(CorsHandler.VARY).get(0));
assertEquals("true", response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_CREDENTIALS).get(0));
}
public void testThatAnyOriginWorks() {
final String originValue = CorsHandler.ANY_ORIGIN;
Settings settings = Settings.builder()
.put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true)
.put(HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue)
.build();
TestHttpResponse response = executeRequest(settings, originValue, "https://127.0.0.1");
assertEquals(originValue, response.headers().get(CorsHandler.ACCESS_CONTROL_ALLOW_ORIGIN).get(0));
assertNull(response.headers().get(CorsHandler.VARY));
}
public void testHeadersSet() {
Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext(), null);
threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
TestRestResponse resp = new TestRestResponse();
final String customHeader = "custom-header";
final String customHeaderValue = "xyz";
@ -200,10 +161,10 @@ public class DefaultRestChannelTests extends ESTestCase {
channel.sendResponse(resp);
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse httpResponse = responseCaptor.getValue();
Map<String, List<String>> headers = httpResponse.headers;
TestHttpResponse httpResponse = responseCaptor.getValue();
Map<String, List<String>> headers = httpResponse.headers();
assertNull(headers.get("non-existent-header"));
assertEquals(customHeaderValue, headers.get(customHeader).get(0));
assertEquals("abc", headers.get(Task.X_OPAQUE_ID).get(0));
@ -213,21 +174,21 @@ public class DefaultRestChannelTests extends ESTestCase {
public void testCookiesSet() {
Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest.getHeaders().put(Task.X_OPAQUE_ID, Collections.singletonList("abc"));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
// send a response
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext(), null);
threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
channel.sendResponse(new TestRestResponse());
// inspect what was written
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
verify(httpChannel).sendResponse(responseCaptor.capture(), any());
TestResponse nioResponse = responseCaptor.getValue();
Map<String, List<String>> headers = nioResponse.headers;
TestHttpResponse nioResponse = responseCaptor.getValue();
Map<String, List<String>> headers = nioResponse.headers();
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie"));
assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2"));
}
@ -235,12 +196,12 @@ public class DefaultRestChannelTests extends ESTestCase {
@SuppressWarnings("unchecked")
public void testReleaseInListener() throws IOException {
final Settings settings = Settings.builder().build();
final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final TestHttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext(), null);
threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR,
JsonXContent.contentBuilder().startObject().endObject());
assertThat(response.content(), not(instanceOf(Releasable.class)));
@ -276,16 +237,16 @@ public class DefaultRestChannelTests extends ESTestCase {
final boolean brokenRequest = randomBoolean();
final boolean close = brokenRequest || randomBoolean();
if (brokenRequest) {
httpRequest = new TestRequest(() -> {
httpRequest = new TestHttpRequest(() -> {
throw new IllegalArgumentException("Can't parse HTTP version");
}, RestRequest.Method.GET, "/");
} else if (randomBoolean()) {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
if (close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE));
}
} else {
httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/");
if (!close) {
httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE));
}
@ -295,7 +256,7 @@ public class DefaultRestChannelTests extends ESTestCase {
HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings);
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings,
threadPool.getThreadContext(), null);
threadPool.getThreadContext(), CorsHandler.fromSettings(settings), null);
channel.sendResponse(new TestRestResponse());
Class<ActionListener<Void>> listenerClass = (Class<ActionListener<Void>>) (Class) ActionListener.class;
ArgumentCaptor<ActionListener<Void>> listenerCaptor = ArgumentCaptor.forClass(listenerClass);
@ -317,7 +278,7 @@ public class DefaultRestChannelTests extends ESTestCase {
final boolean close = randomBoolean();
final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1;
final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE;
final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") {
final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") {
@Override
public RestRequest.Method method() {
throw new IllegalArgumentException("test");
@ -326,7 +287,8 @@ public class DefaultRestChannelTests extends ESTestCase {
request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue));
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays,
HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null);
HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY),
null);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
@ -354,7 +316,7 @@ public class DefaultRestChannelTests extends ESTestCase {
final boolean close = randomBoolean();
final HttpRequest.HttpVersion httpVersion = close ? HttpRequest.HttpVersion.HTTP_1_0 : HttpRequest.HttpVersion.HTTP_1_1;
final String httpConnectionHeaderValue = close ? DefaultRestChannel.CLOSE : DefaultRestChannel.KEEP_ALIVE;
final RestRequest request = RestRequest.request(xContentRegistry(), new TestRequest(httpVersion, null, "/") {
final RestRequest request = RestRequest.request(xContentRegistry(), new TestHttpRequest(httpVersion, null, "/") {
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
throw new IllegalArgumentException("test");
@ -363,7 +325,8 @@ public class DefaultRestChannelTests extends ESTestCase {
request.getHttpRequest().getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(httpConnectionHeaderValue));
DefaultRestChannel channel = new DefaultRestChannel(httpChannel, request.getHttpRequest(), request, bigArrays,
HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), null);
HttpHandlingSettings.fromSettings(Settings.EMPTY), threadPool.getThreadContext(), CorsHandler.fromSettings(Settings.EMPTY),
null);
// ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released
final BigArrays bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService());
@ -379,142 +342,29 @@ public class DefaultRestChannelTests extends ESTestCase {
}
}
private TestResponse executeRequest(final Settings settings, final String host) {
private TestHttpResponse executeRequest(final Settings settings, final String host) {
return executeRequest(settings, null, host);
}
private TestResponse executeRequest(final Settings settings, final String originValue, final String host) {
HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
// TODO: These exist for the Cors tests
// if (originValue != null) {
// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue);
// }
// httpRequest.headers().add(HttpHeaderNames.HOST, host);
private TestHttpResponse executeRequest(final Settings settings, final String originValue, final String host) {
HttpRequest httpRequest = new TestHttpRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/");
if (originValue != null) {
httpRequest.getHeaders().put(CorsHandler.ORIGIN, Collections.singletonList(originValue));
}
httpRequest.getHeaders().put(CorsHandler.HOST, Collections.singletonList(host));
final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel);
HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);
RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings,
threadPool.getThreadContext(), null);
threadPool.getThreadContext(), new CorsHandler(CorsHandler.buildConfig(settings)), null);
channel.sendResponse(new TestRestResponse());
// get the response
ArgumentCaptor<TestResponse> responseCaptor = ArgumentCaptor.forClass(TestResponse.class);
ArgumentCaptor<TestHttpResponse> responseCaptor = ArgumentCaptor.forClass(TestHttpResponse.class);
verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any());
return responseCaptor.getValue();
}
private static class TestRequest implements HttpRequest {
private final Supplier<HttpVersion> version;
private final RestRequest.Method method;
private final String uri;
private HashMap<String, List<String>> headers = new HashMap<>();
private TestRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
this.version = versionSupplier;
this.method = method;
this.uri = uri;
}
private TestRequest(HttpVersion version, RestRequest.Method method, String uri) {
this(() -> version, method, uri);
}
@Override
public RestRequest.Method method() {
return method;
}
@Override
public String uri() {
return uri;
}
@Override
public BytesReference content() {
return BytesArray.EMPTY;
}
@Override
public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Arrays.asList("cookie", "cookie2");
}
@Override
public HttpVersion protocolVersion() {
return version.get();
}
@Override
public HttpRequest removeHeader(String header) {
throw new UnsupportedOperationException("Do not support removing header on test request.");
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return new TestResponse(status, content);
}
@Override
public void release() {
}
@Override
public HttpRequest releaseAndCopy() {
return this;
}
@Override
public Exception getInboundException() {
return null;
}
}
private static class TestResponse implements HttpResponse {
private final RestStatus status;
private final BytesReference content;
private final Map<String, List<String>> headers = new HashMap<>();
TestResponse(RestStatus status, BytesReference content) {
this.status = status;
this.content = content;
}
public String contentType() {
return "text";
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return status;
}
@Override
public void addHeader(String name, String value) {
if (headers.containsKey(name) == false) {
ArrayList<String> values = new ArrayList<>();
values.add(value);
headers.put(name, values);
} else {
headers.get(name).add(value);
}
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
}
private static class TestRestResponse extends RestResponse {
private final RestStatus status;

View File

@ -0,0 +1,103 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
class TestHttpRequest implements HttpRequest {
private final Supplier<HttpVersion> version;
private final RestRequest.Method method;
private final String uri;
private final HashMap<String, List<String>> headers = new HashMap<>();
TestHttpRequest(Supplier<HttpVersion> versionSupplier, RestRequest.Method method, String uri) {
this.version = versionSupplier;
this.method = method;
this.uri = uri;
}
TestHttpRequest(HttpVersion version, RestRequest.Method method, String uri) {
this(() -> version, method, uri);
}
@Override
public RestRequest.Method method() {
return method;
}
@Override
public String uri() {
return uri;
}
@Override
public BytesReference content() {
return BytesArray.EMPTY;
}
@Override
public Map<String, List<String>> getHeaders() {
return headers;
}
@Override
public List<String> strictCookies() {
return Arrays.asList("cookie", "cookie2");
}
@Override
public HttpVersion protocolVersion() {
return version.get();
}
@Override
public HttpRequest removeHeader(String header) {
throw new UnsupportedOperationException("Do not support removing header on test request.");
}
@Override
public HttpResponse createResponse(RestStatus status, BytesReference content) {
return new TestHttpResponse(status, content);
}
@Override
public void release() {
}
@Override
public HttpRequest releaseAndCopy() {
return this;
}
@Override
public Exception getInboundException() {
return null;
}
}

View File

@ -0,0 +1,68 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.http;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.rest.RestStatus;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
class TestHttpResponse implements HttpResponse {
private final RestStatus status;
private final BytesReference content;
private final Map<String, List<String>> headers = new HashMap<>();
TestHttpResponse(RestStatus status, BytesReference content) {
this.status = status;
this.content = content;
}
public BytesReference content() {
return content;
}
public RestStatus status() {
return status;
}
public Map<String, List<String>> headers() {
return headers;
}
@Override
public void addHeader(String name, String value) {
if (headers.containsKey(name) == false) {
ArrayList<String> values = new ArrayList<>();
values.add(value);
headers.put(name, values);
} else {
headers.get(name).add(value);
}
}
@Override
public boolean containsHeader(String name) {
return headers.containsKey(name);
}
}

View File

@ -94,7 +94,7 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport {
public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel, Config.Socket socketConfig) throws IOException {
NioHttpChannel httpChannel = new NioHttpChannel(channel);
HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this,
handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
handlingSettings, selector.getTaskScheduler(), threadPool::relativeTimeInNanos);
final NioChannelHandler handler;
if (ipFilter != null) {
handler = new NioIPFilter(httpHandler, socketConfig.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME);