diff --git a/core/src/main/java/org/elasticsearch/http/netty/cors/CorsHandler.java b/core/src/main/java/org/elasticsearch/http/netty/cors/CorsHandler.java index b04d9013c4f..7ecc9b1fd5d 100644 --- a/core/src/main/java/org/elasticsearch/http/netty/cors/CorsHandler.java +++ b/core/src/main/java/org/elasticsearch/http/netty/cors/CorsHandler.java @@ -31,6 +31,7 @@ import org.jboss.netty.handler.codec.http.HttpMethod; import org.jboss.netty.handler.codec.http.HttpRequest; import org.jboss.netty.handler.codec.http.HttpResponse; +import java.util.regex.Pattern; import java.util.stream.Collectors; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ACCESS_CONTROL_ALLOW_CREDENTIALS; @@ -38,6 +39,7 @@ import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ACCESS_CONTRO import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ACCESS_CONTROL_ALLOW_METHODS; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ACCESS_CONTROL_MAX_AGE; +import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.HOST; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.ORIGIN; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.USER_AGENT; import static org.jboss.netty.handler.codec.http.HttpHeaders.Names.VARY; @@ -55,8 +57,9 @@ import static org.jboss.netty.handler.codec.http.HttpResponseStatus.OK; public class CorsHandler extends SimpleChannelUpstreamHandler { public static final String ANY_ORIGIN = "*"; - private final CorsConfig config; + private static Pattern SCHEME_PATTERN = Pattern.compile("^https?://"); + private final CorsConfig config; private HttpRequest request; /** @@ -96,7 +99,7 @@ public class CorsHandler extends SimpleChannelUpstreamHandler { final String originHeaderVal; if (config.isAnyOriginSupported()) { originHeaderVal = ANY_ORIGIN; - } else if (config.isOriginAllowed(originHeader)) { + } else if (config.isOriginAllowed(originHeader) || isSameOrigin(originHeader, request.headers().get(HOST))) { originHeaderVal = originHeader; } else { originHeaderVal = null; @@ -129,6 +132,17 @@ public class CorsHandler extends SimpleChannelUpstreamHandler { .addListener(ChannelFutureListener.CLOSE); } + private static boolean isSameOrigin(final String origin, final String host) { + if (Strings.isNullOrEmpty(host) == false) { + // strip protocol from origin + final String originDomain = SCHEME_PATTERN.matcher(origin).replaceFirst(""); + if (host.equals(originDomain)) { + return true; + } + } + return false; + } + /** * This is a non CORS specification feature which enables the setting of preflight * response headers that might be required by intermediaries. @@ -179,6 +193,11 @@ public class CorsHandler extends SimpleChannelUpstreamHandler { return true; } + // if the origin is the same as the host of the request, then allow + if (isSameOrigin(origin, request.headers().get(HOST))) { + return true; + } + return config.isOriginAllowed(origin); } diff --git a/core/src/test/java/org/elasticsearch/http/netty/NettyHttpChannelTests.java b/core/src/test/java/org/elasticsearch/http/netty/NettyHttpChannelTests.java index ce9051ad189..f809ec85280 100644 --- a/core/src/test/java/org/elasticsearch/http/netty/NettyHttpChannelTests.java +++ b/core/src/test/java/org/elasticsearch/http/netty/NettyHttpChannelTests.java @@ -63,8 +63,6 @@ import static org.hamcrest.Matchers.nullValue; public class NettyHttpChannelTests extends ESTestCase { - private static final String ORIGIN = "remote-host"; - private NetworkService networkService; private ThreadPool threadPool; private MockBigArrays bigArrays; @@ -93,34 +91,67 @@ public class NettyHttpChannelTests extends ESTestCase { Settings settings = Settings.builder() .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) .build(); - HttpResponse response = execRequestWithCors(settings, ORIGIN); + HttpResponse response = execRequestWithCors(settings, "remote-host", "request-host"); // inspect response and validate assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); } public void testCorsEnabledWithAllowOrigins() { - final String originValue = ORIGIN; + final String originValue = "remote-host"; // create a http transport with CORS enabled and allow origin configured Settings settings = Settings.builder() .put(SETTING_CORS_ENABLED.getKey(), true) .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) .build(); - HttpResponse response = execRequestWithCors(settings, originValue); + HttpResponse response = execRequestWithCors(settings, originValue, "request-host"); // inspect response and validate assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); String allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); assertThat(allowedOrigins, is(originValue)); } + public void testCorsAllowOriginWithSameHost() { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = execRequestWithCors(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = execRequestWithCors(settings, originValue, host); + assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = execRequestWithCors(settings, originValue, host); + assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = execRequestWithCors(settings, originValue, host); + assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + public void testThatStringLiteralWorksOnMatch() { - final String originValue = ORIGIN; + final String originValue = "remote-host"; Settings settings = Settings.builder() .put(SETTING_CORS_ENABLED.getKey(), true) .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) .build(); - HttpResponse response = execRequestWithCors(settings, originValue); + HttpResponse response = execRequestWithCors(settings, originValue, "request-host"); // inspect response and validate assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); String allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); @@ -134,7 +165,7 @@ public class NettyHttpChannelTests extends ESTestCase { .put(SETTING_CORS_ENABLED.getKey(), true) .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) .build(); - HttpResponse response = execRequestWithCors(settings, originValue); + HttpResponse response = execRequestWithCors(settings, originValue, "request-host"); // inspect response and validate assertThat(response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); String allowedOrigins = response.headers().get(HttpHeaders.Names.ACCESS_CONTROL_ALLOW_ORIGIN); @@ -169,12 +200,13 @@ public class NettyHttpChannelTests extends ESTestCase { assertThat(response.headers().get(HttpHeaders.Names.CONTENT_TYPE), equalTo(resp.contentType())); } - private HttpResponse execRequestWithCors(final Settings settings, final String originValue) { + private HttpResponse execRequestWithCors(final Settings settings, final String originValue, final String host) { // construct request and send it over the transport layer httpServerTransport = new NettyHttpServerTransport(settings, networkService, bigArrays, threadPool); HttpRequest httpRequest = new TestHttpRequest(); - httpRequest.headers().add(HttpHeaders.Names.ORIGIN, ORIGIN); + httpRequest.headers().add(HttpHeaders.Names.ORIGIN, originValue); httpRequest.headers().add(HttpHeaders.Names.USER_AGENT, "Mozilla fake"); + httpRequest.headers().add(HttpHeaders.Names.HOST, host); WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); NettyHttpRequest request = new NettyHttpRequest(httpRequest, writeCapturingChannel);