From 4877cec3e82554cfc7a03f3552b3c62198af0d28 Mon Sep 17 00:00:00 2001 From: David Turner Date: Thu, 14 Jun 2018 13:41:25 +0100 Subject: [PATCH 1/8] More detailed tracing when writing metadata (#31319) Packaging tests are occasionally failing (#30295) because of very slow index template creation. It looks like the slow part is updating the on-disk cluster state, and this change will help to confirm this. --- .../gateway/MetaDataStateFormat.java | 21 ++++++++++++------- .../gateway/MetaStateService.java | 2 ++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java b/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java index 0821b176e75..e048512e638 100644 --- a/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java +++ b/server/src/main/java/org/elasticsearch/gateway/MetaDataStateFormat.java @@ -29,6 +29,7 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.OutputStreamIndexOutput; import org.apache.lucene.store.SimpleFSDirectory; +import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.bytes.BytesArray; @@ -76,6 +77,7 @@ public abstract class MetaDataStateFormat { private final String prefix; private final Pattern stateFilePattern; + private static final Logger logger = Loggers.getLogger(MetaDataStateFormat.class); /** * Creates a new {@link MetaDataStateFormat} instance @@ -134,6 +136,7 @@ public abstract class MetaDataStateFormat { IOUtils.fsync(tmpStatePath, false); // fsync the state file Files.move(tmpStatePath, finalStatePath, StandardCopyOption.ATOMIC_MOVE); IOUtils.fsync(stateLocation, true); + logger.trace("written state to {}", finalStatePath); for (int i = 1; i < locations.length; i++) { stateLocation = locations[i].resolve(STATE_DIR_NAME); Files.createDirectories(stateLocation); @@ -145,12 +148,15 @@ public abstract class MetaDataStateFormat { // we are on the same FileSystem / Partition here we can do an atomic move Files.move(tmpPath, finalPath, StandardCopyOption.ATOMIC_MOVE); IOUtils.fsync(stateLocation, true); + logger.trace("copied state to {}", finalPath); } finally { Files.deleteIfExists(tmpPath); + logger.trace("cleaned up {}", tmpPath); } } } finally { Files.deleteIfExists(tmpStatePath); + logger.trace("cleaned up {}", tmpStatePath); } cleanupOldFiles(prefix, fileName, locations); } @@ -211,20 +217,19 @@ public abstract class MetaDataStateFormat { } private void cleanupOldFiles(final String prefix, final String currentStateFile, Path[] locations) throws IOException { - final DirectoryStream.Filter filter = new DirectoryStream.Filter() { - @Override - public boolean accept(Path entry) throws IOException { - final String entryFileName = entry.getFileName().toString(); - return Files.isRegularFile(entry) - && entryFileName.startsWith(prefix) // only state files - && currentStateFile.equals(entryFileName) == false; // keep the current state file around - } + final DirectoryStream.Filter filter = entry -> { + final String entryFileName = entry.getFileName().toString(); + return Files.isRegularFile(entry) + && entryFileName.startsWith(prefix) // only state files + && currentStateFile.equals(entryFileName) == false; // keep the current state file around }; // now clean up the old files for (Path dataLocation : locations) { + logger.trace("cleanupOldFiles: cleaning up {}", dataLocation); try (DirectoryStream stream = Files.newDirectoryStream(dataLocation.resolve(STATE_DIR_NAME), filter)) { for (Path stateFile : stream) { Files.deleteIfExists(stateFile); + logger.trace("cleanupOldFiles: cleaned up {}", stateFile); } } } diff --git a/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java b/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java index 00b981175f2..fd1698bb006 100644 --- a/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java +++ b/server/src/main/java/org/elasticsearch/gateway/MetaStateService.java @@ -123,6 +123,7 @@ public class MetaStateService extends AbstractComponent { try { IndexMetaData.FORMAT.write(indexMetaData, nodeEnv.indexPaths(indexMetaData.getIndex())); + logger.trace("[{}] state written", index); } catch (Exception ex) { logger.warn(() -> new ParameterizedMessage("[{}]: failed to write index state", index), ex); throw new IOException("failed to write state for [" + index + "]", ex); @@ -136,6 +137,7 @@ public class MetaStateService extends AbstractComponent { logger.trace("[_global] writing state, reason [{}]", reason); try { MetaData.FORMAT.write(metaData, nodeEnv.nodeDataPaths()); + logger.trace("[_global] state written"); } catch (Exception ex) { logger.warn("[_global]: failed to write global state", ex); throw new IOException("failed to write global state", ex); From 375d09c588a58a0f490bac29fb535922819e3453 Mon Sep 17 00:00:00 2001 From: Simon Willnauer Date: Thu, 14 Jun 2018 16:21:28 +0200 Subject: [PATCH 2/8] [TEST] Fix RemoteClusterClientTests#testEnsureWeReconnect Closes #29547 --- .../transport/RemoteClusterClientTests.java | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java index a497e509c15..8cfec0a07f9 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterClientTests.java @@ -30,6 +30,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import java.util.Collections; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import static org.elasticsearch.transport.RemoteClusterConnectionTests.startTransport; @@ -69,7 +70,6 @@ public class RemoteClusterClientTests extends ESTestCase { } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/29547") public void testEnsureWeReconnect() throws Exception { Settings remoteSettings = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "foo_bar_cluster").build(); try (MockTransportService remoteTransport = startTransport("remote_node", Collections.emptyList(), Version.CURRENT, threadPool, @@ -79,17 +79,35 @@ public class RemoteClusterClientTests extends ESTestCase { .put(RemoteClusterService.ENABLE_REMOTE_CLUSTERS.getKey(), true) .put("search.remote.test.seeds", remoteNode.getAddress().getAddress() + ":" + remoteNode.getAddress().getPort()).build(); try (MockTransportService service = MockTransportService.createNewService(localSettings, Version.CURRENT, threadPool, null)) { + Semaphore semaphore = new Semaphore(1); service.start(); + service.addConnectionListener(new TransportConnectionListener() { + @Override + public void onNodeDisconnected(DiscoveryNode node) { + if (remoteNode.equals(node)) { + semaphore.release(); + } + } + }); + // this test is not perfect since we might reconnect concurrently but it will fail most of the time if we don't have + // the right calls in place in the RemoteAwareClient service.acceptIncomingRequests(); - service.disconnectFromNode(remoteNode); - RemoteClusterService remoteClusterService = service.getRemoteClusterService(); - assertBusy(() -> assertFalse(remoteClusterService.isRemoteNodeConnected("test", remoteNode))); - Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test"); - ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get(); - assertNotNull(clusterStateResponse); - assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value()); + for (int i = 0; i < 10; i++) { + semaphore.acquire(); + try { + service.disconnectFromNode(remoteNode); + semaphore.acquire(); + RemoteClusterService remoteClusterService = service.getRemoteClusterService(); + Client client = remoteClusterService.getRemoteClusterClient(threadPool, "test"); + ClusterStateResponse clusterStateResponse = client.admin().cluster().prepareState().execute().get(); + assertNotNull(clusterStateResponse); + assertEquals("foo_bar_cluster", clusterStateResponse.getState().getClusterName().value()); + assertTrue(remoteClusterService.isRemoteNodeConnected("test", remoteNode)); + } finally { + semaphore.release(); + } + } } } } - } From f7a0cafe557275cf1d3b6f87209f9c41ecb39b8d Mon Sep 17 00:00:00 2001 From: Costin Leau Date: Thu, 14 Jun 2018 18:07:29 +0300 Subject: [PATCH 3/8] SQL: Fix build on Java 10 Due to a runtime classpath clash, featureAware task was failing on JVMs higher than 1.8 (since the ASM version from Painless was used instead which does not recognized Java 9 or 10 bytecode) causing the task to fail. This commit excludes the ASM dependency (since it's not used by SQL itself). --- x-pack/plugin/build.gradle | 4 +--- x-pack/plugin/sql/build.gradle | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/build.gradle b/x-pack/plugin/build.gradle index de4d3ada51a..ac423c42811 100644 --- a/x-pack/plugin/build.gradle +++ b/x-pack/plugin/build.gradle @@ -43,9 +43,7 @@ subprojects { final FileCollection classDirectories = project.files(files).filter { it.exists() } doFirst { - String cp = project.configurations.featureAwarePlugin.asPath - cp = cp.replaceAll(":[^:]*/asm-debug-all-5.1.jar:", ":") - args('-cp', cp, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck') + args('-cp', project.configurations.featureAwarePlugin.asPath, 'org.elasticsearch.xpack.test.feature_aware.FeatureAwareCheck') classDirectories.each { args it.getAbsolutePath() } } doLast { diff --git a/x-pack/plugin/sql/build.gradle b/x-pack/plugin/sql/build.gradle index 8b406235985..19dd1a08ec6 100644 --- a/x-pack/plugin/sql/build.gradle +++ b/x-pack/plugin/sql/build.gradle @@ -20,7 +20,10 @@ integTest.enabled = false dependencies { compileOnly "org.elasticsearch.plugin:x-pack-core:${version}" - compileOnly project(':modules:lang-painless') + compileOnly(project(':modules:lang-painless')) { + // exclude ASM to not affect featureAware task on Java 10+ + exclude group: "org.ow2.asm" + } compile project('sql-proto') compile "org.elasticsearch.plugin:aggs-matrix-stats-client:${version}" compile "org.antlr:antlr4-runtime:4.5.3" From 9b293275af8990917ee2152dbf53d8642bda264d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 14 Jun 2018 16:52:32 +0100 Subject: [PATCH 4/8] [ML] Add description to ML filters (#31330) This adds a `description` to ML filters in order to allow users to describe their filters in a human readable form which is also editable (filter updates to be added shortly). --- .../xpack/core/ml/job/config/MlFilter.java | 47 ++++++++++++++++--- .../action/GetFiltersActionResponseTests.java | 5 +- .../action/PutFilterActionRequestTests.java | 14 +----- .../core/ml/job/config/MlFilterTests.java | 19 ++++++-- .../xpack/ml/integration/JobProviderIT.java | 4 +- .../xpack/ml/job/JobManagerTests.java | 2 +- .../ControlMsgToProcessWriterTests.java | 4 +- .../writer/FieldConfigWriterTests.java | 4 +- .../writer/MlFilterWriterTests.java | 5 +- .../rest-api-spec/test/ml/filter_crud.yml | 2 + .../ml/integration/DetectionRulesIT.java | 6 +-- 11 files changed, 75 insertions(+), 37 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java index de6ee3d509c..991f421265e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/MlFilter.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.job.config; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -30,6 +31,7 @@ public class MlFilter implements ToXContentObject, Writeable { public static final ParseField TYPE = new ParseField("type"); public static final ParseField ID = new ParseField("filter_id"); + public static final ParseField DESCRIPTION = new ParseField("description"); public static final ParseField ITEMS = new ParseField("items"); // For QueryPage @@ -43,27 +45,38 @@ public class MlFilter implements ToXContentObject, Writeable { parser.declareString((builder, s) -> {}, TYPE); parser.declareString(Builder::setId, ID); + parser.declareStringOrNull(Builder::setDescription, DESCRIPTION); parser.declareStringArray(Builder::setItems, ITEMS); return parser; } private final String id; + private final String description; private final List items; - public MlFilter(String id, List items) { + public MlFilter(String id, String description, List items) { this.id = Objects.requireNonNull(id, ID.getPreferredName() + " must not be null"); + this.description = description; this.items = Objects.requireNonNull(items, ITEMS.getPreferredName() + " must not be null"); } public MlFilter(StreamInput in) throws IOException { id = in.readString(); + if (in.getVersion().onOrAfter(Version.V_6_4_0)) { + description = in.readOptionalString(); + } else { + description = null; + } items = Arrays.asList(in.readStringArray()); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); + if (out.getVersion().onOrAfter(Version.V_6_4_0)) { + out.writeOptionalString(description); + } out.writeStringArray(items.toArray(new String[items.size()])); } @@ -71,6 +84,9 @@ public class MlFilter implements ToXContentObject, Writeable { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(ID.getPreferredName(), id); + if (description != null) { + builder.field(DESCRIPTION.getPreferredName(), description); + } builder.field(ITEMS.getPreferredName(), items); if (params.paramAsBoolean(MlMetaIndex.INCLUDE_TYPE_KEY, false)) { builder.field(TYPE.getPreferredName(), FILTER_TYPE); @@ -83,6 +99,10 @@ public class MlFilter implements ToXContentObject, Writeable { return id; } + public String getDescription() { + return description; + } + public List getItems() { return new ArrayList<>(items); } @@ -98,12 +118,12 @@ public class MlFilter implements ToXContentObject, Writeable { } MlFilter other = (MlFilter) obj; - return id.equals(other.id) && items.equals(other.items); + return id.equals(other.id) && Objects.equals(description, other.description) && items.equals(other.items); } @Override public int hashCode() { - return Objects.hash(id, items); + return Objects.hash(id, description, items); } public String documentId() { @@ -114,30 +134,45 @@ public class MlFilter implements ToXContentObject, Writeable { return DOCUMENT_ID_PREFIX + filterId; } + public static Builder builder(String filterId) { + return new Builder().setId(filterId); + } + public static class Builder { private String id; + private String description; private List items = Collections.emptyList(); + private Builder() {} + public Builder setId(String id) { this.id = id; return this; } - private Builder() {} - @Nullable public String getId() { return id; } + public Builder setDescription(String description) { + this.description = description; + return this; + } + public Builder setItems(List items) { this.items = items; return this; } + public Builder setItems(String... items) { + this.items = Arrays.asList(items); + return this; + } + public MlFilter build() { - return new MlFilter(id, items); + return new MlFilter(id, description, items); } } } \ No newline at end of file diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java index c8465c87587..7bda0f6e7de 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetFiltersActionResponseTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.test.AbstractStreamableTestCase; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction.Response; import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.job.config.MlFilter; +import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests; import java.util.Collections; @@ -17,9 +18,7 @@ public class GetFiltersActionResponseTests extends AbstractStreamableTestCase result; - - MlFilter doc = new MlFilter( - randomAlphaOfLengthBetween(1, 20), Collections.singletonList(randomAlphaOfLengthBetween(1, 20))); + MlFilter doc = MlFilterTests.createRandom(); result = new QueryPage<>(Collections.singletonList(doc), 1, MlFilter.RESULTS_FIELD); return new Response(result); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java index 21845922470..dfc3f5f37f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutFilterActionRequestTests.java @@ -8,10 +8,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractStreamableXContentTestCase; import org.elasticsearch.xpack.core.ml.action.PutFilterAction.Request; -import org.elasticsearch.xpack.core.ml.job.config.MlFilter; - -import java.util.ArrayList; -import java.util.List; +import org.elasticsearch.xpack.core.ml.job.config.MlFilterTests; public class PutFilterActionRequestTests extends AbstractStreamableXContentTestCase { @@ -19,13 +16,7 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC @Override protected Request createTestInstance() { - int size = randomInt(10); - List items = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - items.add(randomAlphaOfLengthBetween(1, 20)); - } - MlFilter filter = new MlFilter(filterId, items); - return new PutFilterAction.Request(filter); + return new PutFilterAction.Request(MlFilterTests.createRandom(filterId)); } @Override @@ -42,5 +33,4 @@ public class PutFilterActionRequestTests extends AbstractStreamableXContentTestC protected Request doParseInstance(XContentParser parser) { return PutFilterAction.Request.parseRequest(filterId, parser); } - } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java index 1b61e3ec9a4..78d87b82839 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/MlFilterTests.java @@ -26,12 +26,25 @@ public class MlFilterTests extends AbstractSerializingTestCase { @Override protected MlFilter createTestInstance() { + return createRandom(); + } + + public static MlFilter createRandom() { + return createRandom(randomAlphaOfLengthBetween(1, 20)); + } + + public static MlFilter createRandom(String filterId) { + String description = null; + if (randomBoolean()) { + description = randomAlphaOfLength(20); + } + int size = randomInt(10); List items = new ArrayList<>(size); for (int i = 0; i < size; i++) { items.add(randomAlphaOfLengthBetween(1, 20)); } - return new MlFilter(randomAlphaOfLengthBetween(1, 20), items); + return new MlFilter(filterId, description, items); } @Override @@ -45,13 +58,13 @@ public class MlFilterTests extends AbstractSerializingTestCase { } public void testNullId() { - NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, Collections.emptyList())); + NullPointerException ex = expectThrows(NullPointerException.class, () -> new MlFilter(null, "", Collections.emptyList())); assertEquals(MlFilter.ID.getPreferredName() + " must not be null", ex.getMessage()); } public void testNullItems() { NullPointerException ex = - expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), null)); + expectThrows(NullPointerException.class, () -> new MlFilter(randomAlphaOfLengthBetween(1, 20), "", null)); assertEquals(MlFilter.ITEMS.getPreferredName() + " must not be null", ex.getMessage()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java index 7e0dc453f07..856b930ac49 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/JobProviderIT.java @@ -385,8 +385,8 @@ public class JobProviderIT extends MlSingleNodeTestCase { indexScheduledEvents(events); List filters = new ArrayList<>(); - filters.add(new MlFilter("fruit", Arrays.asList("apple", "pear"))); - filters.add(new MlFilter("tea", Arrays.asList("green", "builders"))); + filters.add(MlFilter.builder("fruit").setItems("apple", "pear").build()); + filters.add(MlFilter.builder("tea").setItems("green", "builders").build()); indexFilters(filters); DataCounts earliestCounts = DataCountsTests.createTestInstance(jobId); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java index 454f941d6c8..42b0a56f49a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobManagerTests.java @@ -210,7 +210,7 @@ public class JobManagerTests extends ESTestCase { JobManager jobManager = createJobManager(); - MlFilter filter = new MlFilter("foo_filter", Arrays.asList("a", "b")); + MlFilter filter = MlFilter.builder("foo_filter").setItems("a", "b").build(); jobManager.updateProcessOnFilterChanged(filter); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java index 8c32a5bb40d..3d08f5a1c25 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/ControlMsgToProcessWriterTests.java @@ -207,8 +207,8 @@ public class ControlMsgToProcessWriterTests extends ESTestCase { public void testWriteUpdateFiltersMessage() throws IOException { ControlMsgToProcessWriter writer = new ControlMsgToProcessWriter(lengthEncodedWriter, 2); - MlFilter filter1 = new MlFilter("filter_1", Arrays.asList("a")); - MlFilter filter2 = new MlFilter("filter_2", Arrays.asList("b", "c")); + MlFilter filter1 = MlFilter.builder("filter_1").setItems("a").build(); + MlFilter filter2 = MlFilter.builder("filter_2").setItems("b", "c").build(); writer.writeUpdateFiltersMessage(Arrays.asList(filter1, filter2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java index bf08d09bf09..d26dbb203c8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/FieldConfigWriterTests.java @@ -220,8 +220,8 @@ public class FieldConfigWriterTests extends ESTestCase { AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(d)); analysisConfig = builder.build(); - filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); - filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); + filters.add(MlFilter.builder("filter_1").setItems("a", "b").build()); + filters.add(MlFilter.builder("filter_2").setItems("c", "d").build()); writer = mock(OutputStreamWriter.class); createFieldConfigWriter().write(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java index f22f7d85090..12ceb12f462 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/MlFilterWriterTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.xpack.core.ml.job.config.MlFilter; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -28,8 +27,8 @@ public class MlFilterWriterTests extends ESTestCase { public void testWrite() throws IOException { List filters = new ArrayList<>(); - filters.add(new MlFilter("filter_1", Arrays.asList("a", "b"))); - filters.add(new MlFilter("filter_2", Arrays.asList("c", "d"))); + filters.add(MlFilter.builder("filter_1").setItems("a", "b").build()); + filters.add(MlFilter.builder("filter_2").setItems("c", "d").build()); StringBuilder buffer = new StringBuilder(); new MlFilterWriter(filters, buffer).write(); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml index d3165260f4b..a1f7eee0dcc 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/filter_crud.yml @@ -32,6 +32,7 @@ setup: filter_id: filter-foo2 body: > { + "description": "This filter has a description", "items": ["123", "lmnop"] } @@ -76,6 +77,7 @@ setup: - match: filters.1: filter_id: "filter-foo2" + description: "This filter has a description" items: ["123", "lmnop"] - do: diff --git a/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java b/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java index aa53d6255cb..b99170546df 100644 --- a/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java +++ b/x-pack/qa/ml-native-tests/src/test/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java @@ -120,7 +120,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { } public void testScope() throws Exception { - MlFilter safeIps = new MlFilter("safe_ips", Arrays.asList("111.111.111.111", "222.222.222.222")); + MlFilter safeIps = MlFilter.builder("safe_ips").setItems("111.111.111.111", "222.222.222.222").build(); assertThat(putMlFilter(safeIps), is(true)); DetectionRule rule = new DetectionRule.Builder(RuleScope.builder().include("ip", "safe_ips")).build(); @@ -178,7 +178,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { assertThat(records.get(0).getOverFieldValue(), equalTo("333.333.333.333")); // Now let's update the filter - MlFilter updatedFilter = new MlFilter(safeIps.getId(), Collections.singletonList("333.333.333.333")); + MlFilter updatedFilter = MlFilter.builder(safeIps.getId()).setItems("333.333.333.333").build(); assertThat(putMlFilter(updatedFilter), is(true)); // Wait until the notification that the process was updated is indexed @@ -229,7 +229,7 @@ public class DetectionRulesIT extends MlNativeAutodetectIntegTestCase { public void testScopeAndCondition() throws IOException { // We have 2 IPs and they're both safe-listed. List ips = Arrays.asList("111.111.111.111", "222.222.222.222"); - MlFilter safeIps = new MlFilter("safe_ips", ips); + MlFilter safeIps = MlFilter.builder("safe_ips").setItems(ips).build(); assertThat(putMlFilter(safeIps), is(true)); // Ignore if ip in safe list AND actual < 10. From 8f886cd4be82ff30e98d812a1deac0a70ad92bc3 Mon Sep 17 00:00:00 2001 From: Yannick Welsch Date: Thu, 14 Jun 2018 18:32:35 +0200 Subject: [PATCH 5/8] Treat ack timeout more like a publish timeout (#31303) This commit changes the ack timeout mechanism so that its behavior is closer to the publish timeout, i.e., it only comes into play after committing a cluster state. This ensures for example that an index creation request with a low (ack) timeout value does not return before the cluster state that contains information about the newly created index is even committed. --- .../cluster/service/MasterService.java | 72 ++++++---- .../elasticsearch/discovery/Discovery.java | 14 ++ .../discovery/single/SingleNodeDiscovery.java | 2 + .../zen/PublishClusterStateAction.java | 14 +- .../cluster/service/MasterServiceTests.java | 129 ++++++++++++++++++ .../zen/PublishClusterStateActionTests.java | 9 ++ 6 files changed, 207 insertions(+), 33 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java index 1757548c28b..4432d864fd3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java +++ b/server/src/main/java/org/elasticsearch/cluster/service/MasterService.java @@ -50,7 +50,6 @@ import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.threadpool.ThreadPool; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -365,28 +364,11 @@ public class MasterService extends AbstractLifecycleComponent { } public Discovery.AckListener createAckListener(ThreadPool threadPool, ClusterState newClusterState) { - ArrayList ackListeners = new ArrayList<>(); - - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - nonFailedTasks.stream().filter(task -> task.listener instanceof AckedClusterStateTaskListener).forEach(task -> { - final AckedClusterStateTaskListener ackedListener = (AckedClusterStateTaskListener) task.listener; - if (ackedListener.ackTimeout() == null || ackedListener.ackTimeout().millis() == 0) { - ackedListener.onAckTimeout(); - } else { - try { - ackListeners.add(new AckCountDownListener(ackedListener, newClusterState.version(), newClusterState.nodes(), - threadPool)); - } catch (EsRejectedExecutionException ex) { - if (logger.isDebugEnabled()) { - logger.debug("Couldn't schedule timeout thread - node might be shutting down", ex); - } - //timeout straightaway, otherwise we could wait forever as the timeout thread has not started - ackedListener.onAckTimeout(); - } - } - }); - - return new DelegatingAckListener(ackListeners); + return new DelegatingAckListener(nonFailedTasks.stream() + .filter(task -> task.listener instanceof AckedClusterStateTaskListener) + .map(task -> new AckCountDownListener((AckedClusterStateTaskListener) task.listener, newClusterState.version(), + newClusterState.nodes(), threadPool)) + .collect(Collectors.toList())); } public boolean clusterStateUnchanged() { @@ -549,6 +531,13 @@ public class MasterService extends AbstractLifecycleComponent { this.listeners = listeners; } + @Override + public void onCommit(TimeValue commitTime) { + for (Discovery.AckListener listener : listeners) { + listener.onCommit(commitTime); + } + } + @Override public void onNodeAck(DiscoveryNode node, @Nullable Exception e) { for (Discovery.AckListener listener : listeners) { @@ -564,14 +553,16 @@ public class MasterService extends AbstractLifecycleComponent { private final AckedClusterStateTaskListener ackedTaskListener; private final CountDown countDown; private final DiscoveryNode masterNode; + private final ThreadPool threadPool; private final long clusterStateVersion; - private final Future ackTimeoutCallback; + private volatile Future ackTimeoutCallback; private Exception lastFailure; AckCountDownListener(AckedClusterStateTaskListener ackedTaskListener, long clusterStateVersion, DiscoveryNodes nodes, ThreadPool threadPool) { this.ackedTaskListener = ackedTaskListener; this.clusterStateVersion = clusterStateVersion; + this.threadPool = threadPool; this.masterNode = nodes.getMasterNode(); int countDown = 0; for (DiscoveryNode node : nodes) { @@ -581,8 +572,27 @@ public class MasterService extends AbstractLifecycleComponent { } } logger.trace("expecting {} acknowledgements for cluster_state update (version: {})", countDown, clusterStateVersion); - this.countDown = new CountDown(countDown); - this.ackTimeoutCallback = threadPool.schedule(ackedTaskListener.ackTimeout(), ThreadPool.Names.GENERIC, () -> onTimeout()); + this.countDown = new CountDown(countDown + 1); // we also wait for onCommit to be called + } + + @Override + public void onCommit(TimeValue commitTime) { + TimeValue ackTimeout = ackedTaskListener.ackTimeout(); + if (ackTimeout == null) { + ackTimeout = TimeValue.ZERO; + } + final TimeValue timeLeft = TimeValue.timeValueNanos(Math.max(0, ackTimeout.nanos() - commitTime.nanos())); + if (timeLeft.nanos() == 0L) { + onTimeout(); + } else if (countDown.countDown()) { + finish(); + } else { + this.ackTimeoutCallback = threadPool.schedule(timeLeft, ThreadPool.Names.GENERIC, this::onTimeout); + // re-check if onNodeAck has not completed while we were scheduling the timeout + if (countDown.isCountedDown()) { + FutureUtils.cancel(ackTimeoutCallback); + } + } } @Override @@ -599,12 +609,16 @@ public class MasterService extends AbstractLifecycleComponent { } if (countDown.countDown()) { - logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); - FutureUtils.cancel(ackTimeoutCallback); - ackedTaskListener.onAllNodesAcked(lastFailure); + finish(); } } + private void finish() { + logger.trace("all expected nodes acknowledged cluster_state update (version: {})", clusterStateVersion); + FutureUtils.cancel(ackTimeoutCallback); + ackedTaskListener.onAllNodesAcked(lastFailure); + } + public void onTimeout() { if (countDown.fastForward()) { logger.trace("timeout waiting for acknowledgement for cluster_state update (version: {})", clusterStateVersion); diff --git a/server/src/main/java/org/elasticsearch/discovery/Discovery.java b/server/src/main/java/org/elasticsearch/discovery/Discovery.java index 9c708760324..b58f61bac89 100644 --- a/server/src/main/java/org/elasticsearch/discovery/Discovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/Discovery.java @@ -25,6 +25,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.unit.TimeValue; import java.io.IOException; @@ -48,6 +49,19 @@ public interface Discovery extends LifecycleComponent { void publish(ClusterChangedEvent clusterChangedEvent, AckListener ackListener); interface AckListener { + /** + * Should be called when the discovery layer has committed the clusters state (i.e. even if this publication fails, + * it is guaranteed to appear in future publications). + * @param commitTime the time it took to commit the cluster state + */ + void onCommit(TimeValue commitTime); + + /** + * Should be called whenever the discovery layer receives confirmation from a node that it has successfully applied + * the cluster state. In case of failures, an exception should be provided as parameter. + * @param node the node + * @param e the optional exception + */ void onNodeAck(DiscoveryNode node, @Nullable Exception e); } diff --git a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java index cd775e29f5a..d7c37febb5d 100644 --- a/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java +++ b/server/src/main/java/org/elasticsearch/discovery/single/SingleNodeDiscovery.java @@ -30,6 +30,7 @@ import org.elasticsearch.cluster.service.ClusterApplier.ClusterApplyListener; import org.elasticsearch.cluster.service.MasterService; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoveryStats; import org.elasticsearch.transport.TransportService; @@ -61,6 +62,7 @@ public class SingleNodeDiscovery extends AbstractLifecycleComponent implements D public synchronized void publish(final ClusterChangedEvent event, final AckListener ackListener) { clusterState = event.state(); + ackListener.onCommit(TimeValue.ZERO); CountDownLatch latch = new CountDownLatch(1); ClusterApplyListener listener = new ClusterApplyListener() { diff --git a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java index cd87a415263..5398b2a057a 100644 --- a/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java +++ b/server/src/main/java/org/elasticsearch/discovery/zen/PublishClusterStateAction.java @@ -158,7 +158,8 @@ public class PublishClusterStateAction extends AbstractComponent { } try { - innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, sendFullVersion, serializedStates, serializedDiffs); + innerPublish(clusterChangedEvent, nodesToPublishTo, sendingController, ackListener, sendFullVersion, serializedStates, + serializedDiffs); } catch (Discovery.FailedToCommitClusterStateException t) { throw t; } catch (Exception e) { @@ -173,8 +174,9 @@ public class PublishClusterStateAction extends AbstractComponent { } private void innerPublish(final ClusterChangedEvent clusterChangedEvent, final Set nodesToPublishTo, - final SendingController sendingController, final boolean sendFullVersion, - final Map serializedStates, final Map serializedDiffs) { + final SendingController sendingController, final Discovery.AckListener ackListener, + final boolean sendFullVersion, final Map serializedStates, + final Map serializedDiffs) { final ClusterState clusterState = clusterChangedEvent.state(); final ClusterState previousState = clusterChangedEvent.previousState(); @@ -195,8 +197,12 @@ public class PublishClusterStateAction extends AbstractComponent { sendingController.waitForCommit(discoverySettings.getCommitTimeout()); + final long commitTime = System.nanoTime() - publishingStartInNanos; + + ackListener.onCommit(TimeValue.timeValueNanos(commitTime)); + try { - long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - (System.nanoTime() - publishingStartInNanos)); + long timeLeftInNanos = Math.max(0, publishTimeout.nanos() - commitTime); final BlockingClusterStatePublishResponseHandler publishResponseHandler = sendingController.getPublishResponseHandler(); sendingController.setPublishingTimedOut(!publishResponseHandler.awaitAllNodes(TimeValue.timeValueNanos(timeLeftInNanos))); if (sendingController.getPublishingTimedOut()) { diff --git a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java index 1b747f22687..f75363c7ab5 100644 --- a/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java @@ -22,6 +22,7 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.Logger; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.ClusterState; @@ -39,6 +40,7 @@ import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.BaseFuture; +import org.elasticsearch.discovery.Discovery; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockLogAppender; import org.elasticsearch.test.junit.annotations.TestLogging; @@ -65,6 +67,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; @@ -680,6 +683,132 @@ public class MasterServiceTests extends ESTestCase { mockAppender.assertAllExpectationsMatched(); } + public void testAcking() throws InterruptedException { + final DiscoveryNode node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + final DiscoveryNode node3 = new DiscoveryNode("node3", buildNewFakeTransportAddress(), emptyMap(), emptySet(), Version.CURRENT); + TimedMasterService timedMasterService = new TimedMasterService(Settings.builder().put("cluster.name", + MasterServiceTests.class.getSimpleName()).build(), threadPool); + ClusterState initialClusterState = ClusterState.builder(new ClusterName(MasterServiceTests.class.getSimpleName())) + .nodes(DiscoveryNodes.builder() + .add(node1) + .add(node2) + .add(node3) + .localNodeId(node1.getId()) + .masterNodeId(node1.getId())) + .blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK).build(); + final AtomicReference> publisherRef = new AtomicReference<>(); + timedMasterService.setClusterStatePublisher((cce, l) -> publisherRef.get().accept(cce, l)); + timedMasterService.setClusterStateSupplier(() -> initialClusterState); + timedMasterService.start(); + + + // check that we don't time out before even committing the cluster state + { + final CountDownLatch latch = new CountDownLatch(1); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + throw new Discovery.FailedToCommitClusterStateException("mock exception"); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return TimeValue.ZERO; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + fail(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + latch.countDown(); + } + + @Override + public void onAckTimeout() { + fail(); + } + }); + + latch.await(); + } + + // check that we timeout if commit took too long + { + final CountDownLatch latch = new CountDownLatch(2); + + final TimeValue ackTimeout = TimeValue.timeValueMillis(randomInt(100)); + + publisherRef.set((clusterChangedEvent, ackListener) -> { + ackListener.onCommit(TimeValue.timeValueMillis(ackTimeout.millis() + randomInt(100))); + ackListener.onNodeAck(node1, null); + ackListener.onNodeAck(node2, null); + ackListener.onNodeAck(node3, null); + }); + + timedMasterService.submitStateUpdateTask("test2", new AckedClusterStateUpdateTask(null, null) { + @Override + public ClusterState execute(ClusterState currentState) { + return ClusterState.builder(currentState).build(); + } + + @Override + public TimeValue ackTimeout() { + return ackTimeout; + } + + @Override + public TimeValue timeout() { + return null; + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + latch.countDown(); + } + + @Override + protected Void newResponse(boolean acknowledged) { + fail(); + return null; + } + + @Override + public void onFailure(String source, Exception e) { + fail(); + } + + @Override + public void onAckTimeout() { + latch.countDown(); + } + }); + + latch.await(); + } + + timedMasterService.close(); + } + static class TimedMasterService extends MasterService { public volatile Long currentTimeOverride = null; diff --git a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java index c8e85382994..ac1719269e7 100644 --- a/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/zen/PublishClusterStateActionTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.discovery.Discovery; import org.elasticsearch.discovery.DiscoverySettings; import org.elasticsearch.node.Node; @@ -815,9 +816,16 @@ public class PublishClusterStateActionTests extends ESTestCase { public static class AssertingAckListener implements Discovery.AckListener { private final List> errors = new CopyOnWriteArrayList<>(); private final CountDownLatch countDown; + private final CountDownLatch commitCountDown; public AssertingAckListener(int nodeCount) { countDown = new CountDownLatch(nodeCount); + commitCountDown = new CountDownLatch(1); + } + + @Override + public void onCommit(TimeValue commitTime) { + commitCountDown.countDown(); } @Override @@ -830,6 +838,7 @@ public class PublishClusterStateActionTests extends ESTestCase { public void await(long timeout, TimeUnit unit) throws InterruptedException { assertThat(awaitErrors(timeout, unit), emptyIterable()); + assertTrue(commitCountDown.await(timeout, unit)); } public List> awaitErrors(long timeout, TimeUnit unit) throws InterruptedException { From 6dd81ead74fd54247198d380b98053124a30f1ee Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Thu, 14 Jun 2018 16:22:00 -0400 Subject: [PATCH 6/8] Build: Fix the license in the pom zip and tar (#31336) For 6.3 we renamed the `tar` and `zip` distributions to `oss-tar` and `oss-zip`. Then we added new `tar` and `zip` distributions that contain x-pack and are licensed under the Elastic License. Unfortunately we accidentally generated POM files along side the new `tar` and `zip` distributions that incorrectly claimed that they were Apache 2 licensed. Oooops. This fixes the license on the POMs generated for the `tar` and `zip` distributions. --- build.gradle | 14 ++++++++++++++ distribution/archives/build.gradle | 2 ++ x-pack/build.gradle | 10 +--------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/build.gradle b/build.gradle index ec81047e3e6..9bb08cf29db 100644 --- a/build.gradle +++ b/build.gradle @@ -53,9 +53,23 @@ subprojects { description = "Elasticsearch subproject ${project.path}" } +apply plugin: 'nebula.info-scm' +String licenseCommit +if (VersionProperties.elasticsearch.toString().endsWith('-SNAPSHOT')) { + licenseCommit = scminfo.change ?: "master" // leniency for non git builds +} else { + licenseCommit = "v${version}" +} +String elasticLicenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt" + subprojects { + // Default to the apache license project.ext.licenseName = 'The Apache Software License, Version 2.0' project.ext.licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt' + + // But stick the Elastic license url in project.ext so we can get it if we need to switch to it + project.ext.elasticLicenseUrl = elasticLicenseUrl + // we only use maven publish to add tasks for pom generation plugins.withType(MavenPublishPlugin).whenPluginAdded { publishing { diff --git a/distribution/archives/build.gradle b/distribution/archives/build.gradle index c1097b68b89..71606c2c027 100644 --- a/distribution/archives/build.gradle +++ b/distribution/archives/build.gradle @@ -228,6 +228,8 @@ subprojects { check.dependsOn checkNotice if (project.name == 'zip' || project.name == 'tar') { + project.ext.licenseName = 'Elastic License' + project.ext.licenseUrl = ext.elasticLicenseUrl task checkMlCppNotice { dependsOn buildDist, checkExtraction onlyIf toolExists diff --git a/x-pack/build.gradle b/x-pack/build.gradle index 91652b9e150..6a064ff5b7c 100644 --- a/x-pack/build.gradle +++ b/x-pack/build.gradle @@ -5,14 +5,6 @@ import org.elasticsearch.gradle.precommit.LicenseHeadersTask Project xpackRootProject = project -apply plugin: 'nebula.info-scm' -final String licenseCommit -if (version.endsWith('-SNAPSHOT')) { - licenseCommit = xpackRootProject.scminfo.change ?: "master" // leniency for non git builds -} else { - licenseCommit = "v${version}" -} - subprojects { group = 'org.elasticsearch.plugin' ext.xpackRootProject = xpackRootProject @@ -21,7 +13,7 @@ subprojects { ext.xpackModule = { String moduleName -> xpackProject("plugin:${moduleName}").path } ext.licenseName = 'Elastic License' - ext.licenseUrl = "https://raw.githubusercontent.com/elastic/elasticsearch/${licenseCommit}/licenses/ELASTIC-LICENSE.txt" + ext.licenseUrl = ext.elasticLicenseUrl project.ext.licenseFile = rootProject.file('licenses/ELASTIC-LICENSE.txt') project.ext.noticeFile = xpackRootProject.file('NOTICE.txt') From fcf1e41e429b10e03c5cf9b8551636df7519b4c5 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 14 Jun 2018 15:10:02 -0600 Subject: [PATCH 7/8] Extract common http logic to server (#31311) This is related to #28898. With the addition of the http nio transport, we now have two different modules that provide http transports. Currently most of the http logic lives at the module level. However, some of this logic can live in server. In particular, some of the setting of headers, cors, and pipelining. This commit begins this moving in that direction by introducing lower level abstraction (HttpChannel, HttpRequest, and HttpResonse) that is implemented by the modules. The higher level rest request and rest channel work can live entirely in server. --- .../http/netty4/Netty4HttpChannel.java | 260 +------- .../netty4/Netty4HttpPipeliningHandler.java | 2 +- .../http/netty4/Netty4HttpRequest.java | 156 ++--- .../http/netty4/Netty4HttpRequestHandler.java | 129 +--- .../http/netty4/Netty4HttpResponse.java | 100 ++- .../netty4/Netty4HttpServerTransport.java | 54 +- .../http/netty4/cors/Netty4CorsHandler.java | 10 + .../transport/netty4/Netty4Transport.java | 2 +- .../transport/netty4/NettyTcpChannel.java | 7 +- .../http/netty4/Netty4CorsTests.java | 148 +++++ .../http/netty4/Netty4HttpChannelTests.java | 616 ------------------ .../Netty4HttpPipeliningHandlerTests.java | 26 +- .../Netty4HttpServerPipeliningTests.java | 19 +- .../Netty4HttpServerTransportTests.java | 34 - .../http/nio/HttpReadWriteHandler.java | 119 +--- .../http/nio/NioHttpChannel.java | 243 +------ .../http/nio/NioHttpPipeliningHandler.java | 2 +- .../http/nio/NioHttpRequest.java | 107 ++- .../http/nio/NioHttpResponse.java | 97 ++- .../http/nio/NioHttpServerTransport.java | 27 +- .../http/nio/cors/NioCorsHandler.java | 10 + .../http/nio/HttpReadWriteHandlerTests.java | 225 +++++-- .../http/nio/NioHttpChannelTests.java | 349 ---------- .../nio/NioHttpPipeliningHandlerTests.java | 26 +- .../http/nio/NioHttpServerTransportTests.java | 34 - .../http/AbstractHttpServerTransport.java | 103 ++- .../http/DefaultRestChannel.java | 172 +++++ .../org/elasticsearch/http/HttpChannel.java | 58 ++ .../http/HttpPipelinedMessage.java | 21 +- .../http/HttpPipelinedRequest.java | 10 +- .../org/elasticsearch/http/HttpRequest.java | 65 ++ .../org/elasticsearch/http/HttpResponse.java | 32 + .../rest/AbstractRestChannel.java | 2 +- .../elasticsearch/rest/RestController.java | 5 +- .../org/elasticsearch/rest/RestRequest.java | 102 +-- .../org/elasticsearch/rest/RestResponse.java | 14 +- .../AbstractHttpServerTransportTests.java | 93 +++ .../http/DefaultRestChannelTests.java | 444 +++++++++++++ .../rest/BytesRestResponseTests.java | 24 +- .../rest/RestControllerTests.java | 94 +-- .../elasticsearch/rest/RestRequestTests.java | 129 ++-- .../test/rest/FakeRestRequest.java | 143 +++- .../core/security/rest/RestRequestFilter.java | 26 +- .../security/audit/index/IndexAuditTrail.java | 15 +- .../audit/logfile/LoggingAuditTrail.java | 10 +- .../xpack/security/rest/RemoteHostHeader.java | 2 +- .../security/rest/SecurityRestFilter.java | 11 +- .../SecurityNetty4HttpServerTransport.java | 2 +- .../audit/index/IndexAuditTrailTests.java | 5 +- .../security/rest/RestRequestFilterTests.java | 2 +- .../rest/SecurityRestFilterTests.java | 2 + 51 files changed, 2111 insertions(+), 2277 deletions(-) create mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java delete mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java delete mode 100644 plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java create mode 100644 server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpChannel.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpRequest.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpResponse.java create mode 100644 server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java index cb31d444544..473985d2109 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java @@ -19,252 +19,58 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.net.InetSocketAddress; -final class Netty4HttpChannel extends AbstractRestChannel { +public class Netty4HttpChannel implements HttpChannel { - private final Netty4HttpServerTransport transport; private final Channel channel; - private final FullHttpRequest nettyRequest; - private final int sequence; - private final ThreadContext threadContext; - private final HttpHandlingSettings handlingSettings; - /** - * @param transport The corresponding NettyHttpServerTransport where this channel belongs to. - * @param request The request that is handled by this channel. - * @param sequence The pipelining sequence number for this request - * @param handlingSettings true if error messages should include stack traces. - * @param threadContext the thread context for the channel - */ - Netty4HttpChannel(Netty4HttpServerTransport transport, Netty4HttpRequest request, int sequence, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { - super(request, handlingSettings.getDetailedErrorsEnabled()); - this.transport = transport; - this.channel = request.getChannel(); - this.nettyRequest = request.request(); - this.sequence = sequence; - this.threadContext = threadContext; - this.handlingSettings = handlingSettings; + Netty4HttpChannel(Channel channel) { + this.channel = channel; } @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(transport.bigArrays); + public void sendResponse(HttpResponse response, ActionListener listener) { + ChannelPromise writePromise = channel.newPromise(); + writePromise.addListener(f -> { + if (f.isSuccess()) { + listener.onResponse(null); + } else { + final Throwable cause = f.cause(); + Netty4Utils.maybeDie(cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); + } + } + }); + channel.writeAndFlush(response, writePromise); } @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = Netty4Utils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - Netty4CorsHandler.setCorsResponseHeaders(nettyRequest, resp, transport.getCorsConfig()); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - BytesReference content = response.content(); - boolean releaseContent = content instanceof Releasable; - boolean releaseBytesStreamOutput = bytesOutputOrNull() instanceof ReleasableBytesStreamOutput; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - final ChannelPromise promise = channel.newPromise(); - - if (releaseContent) { - promise.addListener(f -> ((Releasable) content).close()); - } - - if (releaseBytesStreamOutput) { - promise.addListener(f -> bytesOutputOrNull().close()); - } - - if (isCloseConnection()) { - promise.addListener(ChannelFutureListener.CLOSE); - } - - Netty4HttpResponse newResponse = new Netty4HttpResponse(sequence, resp); - - channel.writeAndFlush(newResponse, promise); - releaseContent = false; - releaseBytesStreamOutput = false; - } finally { - if (releaseContent) { - ((Releasable) content).close(); - } - if (releaseBytesStreamOutput) { - bytesOutputOrNull().close(); - } - } + public InetSocketAddress getLocalAddress() { + return (InetSocketAddress) channel.localAddress(); } - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); + @Override + public InetSocketAddress getRemoteAddress() { + return (InetSocketAddress) channel.remoteAddress(); } - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } + @Override + public void close() { + channel.close(); } - private void addCookies(HttpResponse resp) { - if (handlingSettings.isResetCookies()) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); - } - } - } - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); - } - - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; - } - - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); - } - - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public Channel getNettyChannel() { + return channel; } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 12c2e9a6857..e6436ccea1a 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -66,7 +66,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler { try { List> readyResponses = aggregator.write(response, promise); for (Tuple readyResponse : readyResponses) { - ctx.write(readyResponse.v1().getResponse(), readyResponse.v2()); + ctx.write(readyResponse.v1(), readyResponse.v2()); } success = true; } catch (IllegalStateException e) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 2ce6ffada67..ffabe5cbbe2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -19,17 +19,22 @@ package org.elasticsearch.http.netty4; -import io.netty.channel.Channel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.net.SocketAddress; import java.util.AbstractMap; import java.util.Collection; import java.util.Collections; @@ -38,25 +43,16 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -public class Netty4HttpRequest extends RestRequest { - +public class Netty4HttpRequest implements HttpRequest { private final FullHttpRequest request; - private final Channel channel; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - /** - * Construct a new request. - * - * @param xContentRegistry the content registry - * @param request the underlying request - * @param channel the channel for the request - * @throws BadParameterException if the parameters can not be decoded - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); + Netty4HttpRequest(FullHttpRequest request, int sequence) { this.request = request; - this.channel = channel; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = Netty4Utils.toBytesReference(request.content()); } else { @@ -64,71 +60,39 @@ public class Netty4HttpRequest extends RestRequest { } } - /** - * Construct a new request. In contrast to - * {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so - * this constructor will not throw a {@link BadParameterException}. - * - * @param xContentRegistry the content registry - * @param params the parameters for the request - * @param uri the path for the request - * @param request the underlying request - * @param channel the channel for the request - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String uri, - final FullHttpRequest request, - final Channel channel) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); - this.request = request; - this.channel = channel; - if (request.content().isReadable()) { - this.content = Netty4Utils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - } - - public FullHttpRequest request() { - return this.request; - } - @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -139,40 +103,64 @@ public class Netty4HttpRequest extends RestRequest { return request.uri(); } - @Override - public boolean hasContent() { - return content.length() > 0; - } - @Override public BytesReference content() { return content; } - /** - * Returns the remote address where this rest request channel is "connected to". The - * returned {@link SocketAddress} is supposed to be down-cast into more - * concrete type such as {@link java.net.InetSocketAddress} to retrieve - * the detailed information. - */ + @Override - public SocketAddress getRemoteAddress() { - return channel.remoteAddress(); + public final Map> getHeaders() { + return headers; } - /** - * Returns the local address where this request channel is bound to. The returned - * {@link SocketAddress} is supposed to be down-cast into more concrete - * type such as {@link java.net.InetSocketAddress} to retrieve the detailed - * information. - */ @Override - public SocketAddress getLocalAddress() { - return channel.localAddress(); + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); } - public Channel getChannel() { - return channel; + @Override + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } + } + + @Override + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new Netty4HttpRequest(requestWithoutHeader, sequence); + } + + @Override + public Netty4HttpResponse createResponse(RestStatus status, BytesReference content) { + return new Netty4HttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { + return request; + } + + int sequence() { + return sequence; } /** @@ -249,7 +237,7 @@ public class Netty4HttpRequest extends RestRequest { @Override public Set>> entrySet() { return httpHeaders.names().stream().map(k -> new AbstractMap.SimpleImmutableEntry<>(k, httpHeaders.getAll(k))) - .collect(Collectors.toSet()); + .collect(Collectors.toSet()); } } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java index c3a010226a4..4547a63a9a2 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java @@ -20,112 +20,51 @@ package org.elasticsearch.http.netty4; import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaders; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpPipelinedRequest; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; - @ChannelHandler.Sharable class Netty4HttpRequestHandler extends SimpleChannelInboundHandler> { private final Netty4HttpServerTransport serverTransport; - private final HttpHandlingSettings handlingSettings; - private final ThreadContext threadContext; - Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { + Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport) { this.serverTransport = serverTransport; - this.handlingSettings = handlingSettings; - this.threadContext = threadContext; } @Override protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest msg) throws Exception { - final FullHttpRequest request = msg.getRequest(); + Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); + FullHttpRequest request = msg.getRequest(); try { + final FullHttpRequest copiedRequest = + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + Unpooled.copiedBuffer(request.content()), + request.headers(), + request.trailingHeaders()); - final FullHttpRequest copy = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - Unpooled.copiedBuffer(request.content()), - request.headers(), - request.trailingHeaders()); - - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final Netty4HttpRequest httpRequest; - { - Netty4HttpRequest innerHttpRequest; - try { - innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel()); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copy, ctx.channel()); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these - * parameter values. - */ - final Netty4HttpChannel channel; - { - Netty4HttpChannel innerChannel; - try { - innerChannel = - new Netty4HttpChannel(serverTransport, httpRequest, msg.getSequence(), handlingSettings, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final Netty4HttpRequest innerRequest = - new Netty4HttpRequest( - serverTransport.xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copy.uri(), - copy, - ctx.channel()); - innerChannel = - new Netty4HttpChannel(serverTransport, innerRequest, msg.getSequence(), handlingSettings, threadContext); - } - channel = innerChannel; - } + Netty4HttpRequest httpRequest = new Netty4HttpRequest(copiedRequest, msg.getSequence()); if (request.decoderResult().isFailure()) { - serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + serverTransport.incomingRequestError(httpRequest, channel, new Exception(cause)); + } else { + serverTransport.incomingRequestError(httpRequest, channel, (Exception) cause); + } } else { - serverTransport.dispatchRequest(httpRequest, channel); + serverTransport.incomingRequest(httpRequest, channel); } } finally { // As we have copied the buffer, we can release the request @@ -133,32 +72,6 @@ class Netty4HttpRequestHandler extends SimpleChannelInboundHandler MAP; + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + } + } + diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 0e18232e01c..6bfd8168dbe 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -39,6 +39,7 @@ import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.util.AttributeKey; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.common.Strings; @@ -53,9 +54,7 @@ import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.HttpHandlingSettings; @@ -149,38 +148,29 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public static final Setting SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE = Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); - protected final BigArrays bigArrays; + private final ByteSizeValue maxInitialLineLength; + private final ByteSizeValue maxHeaderSize; + private final ByteSizeValue maxChunkSize; - protected final ByteSizeValue maxInitialLineLength; - protected final ByteSizeValue maxHeaderSize; - protected final ByteSizeValue maxChunkSize; + private final int workerCount; - protected final int workerCount; + private final int pipeliningMaxEvents; - protected final int pipeliningMaxEvents; + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean reuseAddress; - /** - * The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}. - */ - protected final NamedXContentRegistry xContentRegistry; - - protected final boolean tcpNoDelay; - protected final boolean tcpKeepAlive; - protected final boolean reuseAddress; - - protected final ByteSizeValue tcpSendBufferSize; - protected final ByteSizeValue tcpReceiveBufferSize; - protected final RecvByteBufAllocator recvByteBufAllocator; + private final ByteSizeValue tcpSendBufferSize; + private final ByteSizeValue tcpReceiveBufferSize; + private final RecvByteBufAllocator recvByteBufAllocator; private final int readTimeoutMillis; - protected final int maxCompositeBufferComponents; + private final int maxCompositeBufferComponents; protected volatile ServerBootstrap serverBootstrap; protected final List serverChannels = new ArrayList<>(); - protected final HttpHandlingSettings httpHandlingSettings; - // package private for testing Netty4OpenChannelsHandler serverOpenChannels; @@ -189,16 +179,13 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); - this.bigArrays = bigArrays; - this.xContentRegistry = xContentRegistry; this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings); this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings); @@ -398,26 +385,27 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { } public ChannelHandler configureServerChannelHandler() { - return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext()); + return new HttpChannelHandler(this, handlingSettings); } + static final AttributeKey HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel"); + protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; private final Netty4HttpRequestHandler requestHandler; private final HttpHandlingSettings handlingSettings; - protected HttpChannelHandler( - final Netty4HttpServerTransport transport, - final HttpHandlingSettings handlingSettings, - final ThreadContext threadContext) { + protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) { this.transport = transport; this.handlingSettings = handlingSettings; - this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext); + this.requestHandler = new Netty4HttpRequestHandler(transport); } @Override protected void initChannel(Channel ch) throws Exception { + Netty4HttpChannel nettyTcpChannel = new Netty4HttpChannel(ch); + ch.attr(HTTP_CHANNEL_KEY).set(nettyTcpChannel); ch.pipeline().addLast("openChannels", transport.serverOpenChannels); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); final HttpRequestDecoder decoder = new HttpRequestDecoder( diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java index 779eb4fe2e4..38d832d6080 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java @@ -22,6 +22,7 @@ package org.elasticsearch.http.netty4.cors; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.netty4.Netty4HttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public class Netty4CorsHandler extends ChannelDuplexHandler { ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass(); + Netty4HttpResponse response = (Netty4HttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise);; + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index f4818a2e567..466c4b68bfa 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -333,10 +333,10 @@ public class Netty4Transport extends TcpTransport { addClosedExceptionLogger(ch); NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name); ch.attr(CHANNEL_KEY).set(nettyTcpChannel); - serverAcceptedChannel(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name)); + serverAcceptedChannel(nettyTcpChannel); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java index f650e757e7a..89fabdcd763 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java @@ -98,8 +98,11 @@ public class NettyTcpChannel implements TcpChannel { } else { final Throwable cause = f.cause(); Netty4Utils.maybeDie(cause); - assert cause instanceof Exception; - listener.onFailure((Exception) cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); + } } }); channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java new file mode 100644 index 00000000000..15a0850f64d --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java @@ -0,0 +1,148 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class Netty4CorsTests extends ESTestCase { + + public void testCorsEnabledWithoutAllowOrigins() { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() { + final String originValue = Netty4CorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { + // construct request and send it over the transport layer + final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + embeddedChannel.pipeline().addLast(new Netty4CorsHandler(Netty4HttpServerTransport.buildCorsConfig(settings))); + Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest, 0); + embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content"))); + return embeddedChannel.readOutbound(); + } +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java deleted file mode 100644 index 7c5b35a3229..00000000000 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java +++ /dev/null @@ -1,616 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.netty4; - -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelConfig; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelId; -import io.netty.channel.ChannelMetadata; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelProgressivePromise; -import io.netty.channel.ChannelPromise; -import io.netty.channel.EventLoop; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.ReleasablePagedBytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ByteArray; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.NullDispatcher; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.netty4.Netty4Utils; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.SocketAddress; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; - -public class Netty4HttpChannelTests extends ESTestCase { - - private NetworkService networkService; - private ThreadPool threadPool; - private MockBigArrays bigArrays; - - @Before - public void setup() throws Exception { - networkService = new NetworkService(Collections.emptyList()); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(Netty4Utils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = Netty4CorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - // send a response - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - HttpResponse response = ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - } - - public void testReleaseOnSendToClosedChannel() { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse response = new TestResponse(bigArrays); - assertThat(response.content(), instanceOf(Releasable.class)); - embeddedChannel.close(); - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testReleaseOnSendToChannelAfterException() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, embeddedChannel); - - // send a response, the channel close status should match - assertTrue(embeddedChannel.isOpen()); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - assertThat(embeddedChannel.isOpen(), equalTo(!close)); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = - new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - return ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - } - } - - private static class WriteCapturingChannel implements Channel { - - private List writtenObjects = new ArrayList<>(); - - @Override - public ChannelId id() { - return null; - } - - @Override - public EventLoop eventLoop() { - return null; - } - - @Override - public Channel parent() { - return null; - } - - @Override - public ChannelConfig config() { - return null; - } - - @Override - public boolean isOpen() { - return false; - } - - @Override - public boolean isRegistered() { - return false; - } - - @Override - public boolean isActive() { - return false; - } - - @Override - public ChannelMetadata metadata() { - return null; - } - - @Override - public SocketAddress localAddress() { - return null; - } - - @Override - public SocketAddress remoteAddress() { - return null; - } - - @Override - public ChannelFuture closeFuture() { - return null; - } - - @Override - public boolean isWritable() { - return false; - } - - @Override - public long bytesBeforeUnwritable() { - return 0; - } - - @Override - public long bytesBeforeWritable() { - return 0; - } - - @Override - public Unsafe unsafe() { - return null; - } - - @Override - public ChannelPipeline pipeline() { - return null; - } - - @Override - public ByteBufAllocator alloc() { - return null; - } - - @Override - public Channel read() { - return null; - } - - @Override - public Channel flush() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture disconnect() { - return null; - } - - @Override - public ChannelFuture close() { - return null; - } - - @Override - public ChannelFuture deregister() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture disconnect(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture close(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture deregister(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture write(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture write(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelPromise newPromise() { - return null; - } - - @Override - public ChannelProgressivePromise newProgressivePromise() { - return null; - } - - @Override - public ChannelFuture newSucceededFuture() { - return null; - } - - @Override - public ChannelFuture newFailedFuture(Throwable cause) { - return null; - } - - @Override - public ChannelPromise voidPromise() { - return null; - } - - @Override - public Attribute attr(AttributeKey key) { - return null; - } - - @Override - public boolean hasAttr(AttributeKey key) { - return false; - } - - @Override - public int compareTo(Channel o) { - return 0; - } - - List getWrittenObjects() { - return writtenObjects; - } - - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = Netty4Utils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - TestResponse(final BigArrays bigArrays) { - final byte[] bytes; - try { - bytes = "content".getBytes("UTF-8"); - } catch (final UnsupportedEncodingException e) { - throw new AssertionError(e); - } - final ByteArray bigArray = bigArrays.newByteArray(bytes.length); - bigArray.set(0, bytes, 0, bytes.length); - reference = new ReleasablePagedBytesReference(bigArrays, bigArray, bytes.length, Releasables.releaseOnce(bigArray)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } - -} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index f6c5dfd5a50..8b3ba19fe01 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -191,11 +190,11 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - Netty4HttpResponse resp = new Netty4HttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -233,10 +232,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -246,9 +245,10 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -260,7 +260,7 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase { waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java index f2b28b90918..3101f660d05 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -26,22 +26,20 @@ import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -120,7 +118,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { @Override public ChannelHandler configureServerChannelHandler() { - return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext()); + return new CustomHttpChannelHandler(this, executorService); } @Override @@ -135,8 +133,8 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { private final ExecutorService executorService; - CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) { - super(transport, transport.httpHandlingSettings, threadContext); + CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService) { + super(transport, transport.handlingSettings); this.executorService = executorService; } @@ -187,8 +185,9 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); - httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); + Netty4HttpRequest httpRequest = new Netty4HttpRequest(fullHttpRequest, pipelinedRequest.getSequence()); + Netty4HttpResponse response = httpRequest.createResponse(RestStatus.OK, new BytesArray(uri.getBytes(StandardCharsets.UTF_8))); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); final boolean slow = uri.matches("/slow/\\d+"); if (slow) { @@ -202,7 +201,7 @@ public class Netty4HttpServerPipeliningTests extends ESTestCase { } final ChannelPromise promise = ctx.newPromise(); - ctx.writeAndFlush(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.writeAndFlush(response, promise); } } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 5b22409b92d..bcf28506143 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -291,40 +291,6 @@ public class Netty4HttpServerTransportTests extends ESTestCase { assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (Netty4HttpServerTransport transport = - new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - public void testReadTimeout() throws Exception { final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index 05f28e8254a..ea75c62dbbc 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -23,54 +23,38 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandler; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; -import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; -import org.elasticsearch.rest.RestRequest; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.function.BiConsumer; - public class HttpReadWriteHandler implements ReadWriteHandler { private final NettyAdaptor adaptor; - private final NioSocketChannel nioChannel; + private final NioHttpChannel nioHttpChannel; private final NioHttpServerTransport transport; - private final HttpHandlingSettings settings; - private final NamedXContentRegistry xContentRegistry; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, - NamedXContentRegistry xContentRegistry, NioCorsConfig corsConfig, ThreadContext threadContext) { - this.nioChannel = nioChannel; + HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, + NioCorsConfig corsConfig) { + this.nioHttpChannel = nioHttpChannel; this.transport = transport; - this.settings = settings; - this.xContentRegistry = xContentRegistry; - this.corsConfig = corsConfig; - this.threadContext = threadContext; List handlers = new ArrayList<>(5); HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(), @@ -89,7 +73,7 @@ public class HttpReadWriteHandler implements ReadWriteHandler { handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents())); adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); - adaptor.addCloseListener((v, e) -> nioChannel.close()); + adaptor.addCloseListener((v, e) -> nioHttpChannel.close()); } @Override @@ -150,95 +134,22 @@ public class HttpReadWriteHandler implements ReadWriteHandler { request.headers(), request.trailingHeaders()); - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final NioHttpRequest httpRequest; - { - NioHttpRequest innerHttpRequest; - try { - innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copiedRequest); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of - * these parameter values. - */ - final NioHttpChannel channel; - { - NioHttpChannel innerChannel; - int sequence = pipelinedRequest.getSequence(); - BigArrays bigArrays = transport.getBigArrays(); - try { - innerChannel = new NioHttpChannel(nioChannel, bigArrays, httpRequest, sequence, settings, corsConfig, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final NioHttpRequest innerRequest = - new NioHttpRequest( - xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copiedRequest.uri(), - copiedRequest); - innerChannel = new NioHttpChannel(nioChannel, bigArrays, innerRequest, sequence, settings, corsConfig, threadContext); - } - channel = innerChannel; - } + NioHttpRequest httpRequest = new NioHttpRequest(copiedRequest, pipelinedRequest.getSequence()); if (request.decoderResult().isFailure()) { - transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - transport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + transport.incomingRequestError(httpRequest, nioHttpChannel, new Exception(cause)); + } else { + transport.incomingRequestError(httpRequest, nioHttpChannel, (Exception) cause); + } } else { - transport.dispatchRequest(httpRequest, channel); + transport.incomingRequest(httpRequest, nioHttpChannel); } } finally { // As we have copied the buffer, we can release the request request.release(); } } - - private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) { - final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); - headersWithoutContentTypeHeader.add(request.headers()); - headersWithoutContentTypeHeader.remove("Content-Type"); - final FullHttpRequest requestWithoutContentTypeHeader = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - request.content(), - headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again - request.trailingHeaders()); // Content-Type can not be a trailing header - try { - return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader); - } catch (final RestRequest.BadParameterException e) { - badRequestCause.addSuppressed(e); - return requestWithoutParameters(requestWithoutContentTypeHeader); - } - } - - private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) { - // remove all parameters as at least one is incorrectly encoded - return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request); - } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java index 634421b34ea..088f0e85dde 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java @@ -19,244 +19,21 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.Cookie; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import java.util.ArrayList; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.BiConsumer; +import java.io.IOException; +import java.nio.channels.SocketChannel; -public class NioHttpChannel extends AbstractRestChannel { +public class NioHttpChannel extends NioSocketChannel implements HttpChannel { - private final BigArrays bigArrays; - private final int sequence; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - private final FullHttpRequest nettyRequest; - private final NioSocketChannel nioChannel; - private final boolean resetCookies; - - NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, int sequence, - HttpHandlingSettings settings, NioCorsConfig corsConfig, ThreadContext threadContext) { - super(request, settings.getDetailedErrorsEnabled()); - this.nioChannel = nioChannel; - this.bigArrays = bigArrays; - this.sequence = sequence; - this.corsConfig = corsConfig; - this.threadContext = threadContext; - this.nettyRequest = request.getRequest(); - this.resetCookies = settings.isResetCookies(); + NioHttpChannel(SocketChannel socketChannel) throws IOException { + super(socketChannel); } - @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = ByteBufUtils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - ArrayList toClose = new ArrayList<>(3); - - boolean success = false; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - BytesReference content = response.content(); - if (content instanceof Releasable) { - toClose.add((Releasable) content); - } - BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); - if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { - toClose.add((Releasable) bytesStreamOutput); - } - - if (isCloseConnection()) { - toClose.add(nioChannel::close); - } - - BiConsumer listener = (aVoid, ex) -> Releasables.close(toClose); - nioChannel.getContext().sendMessage(new NioHttpResponse(sequence, resp), listener); - success = true; - } finally { - if (success == false) { - Releasables.close(toClose); - } - } - } - - @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(bigArrays); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } - } - - private void addCookies(HttpResponse resp) { - if (resetCookies) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); - } - } - } - } - - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); - } - - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); - } - - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public void sendResponse(HttpResponse response, ActionListener listener) { + getContext().sendMessage(response, ActionListener.toBiConsumer(listener)); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java index 1eb63364f99..977092ddac0 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java @@ -68,7 +68,7 @@ public class NioHttpPipeliningHandler extends ChannelDuplexHandler { List> readyResponses = aggregator.write(response, listener); success = true; for (Tuple responseToWrite : readyResponses) { - ctx.write(responseToWrite.v1().getResponse(), responseToWrite.v2()); + ctx.write(responseToWrite.v1(), responseToWrite.v2()); } } catch (IllegalStateException e) { ctx.channel().close(); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java index 4dcd6ba19e0..08937593f3b 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java @@ -19,13 +19,20 @@ package org.elasticsearch.http.nio; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import java.util.AbstractMap; import java.util.Collection; @@ -35,25 +42,17 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -public class NioHttpRequest extends RestRequest { +public class NioHttpRequest implements HttpRequest { private final FullHttpRequest request; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); - this.request = request; - if (request.content().isReadable()) { - this.content = ByteBufUtils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - - } - - NioHttpRequest(NamedXContentRegistry xContentRegistry, Map params, String uri, FullHttpRequest request) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); + NioHttpRequest(FullHttpRequest request, int sequence) { this.request = request; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = ByteBufUtils.toBytesReference(request.content()); } else { @@ -62,38 +61,38 @@ public class NioHttpRequest extends RestRequest { } @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -104,20 +103,66 @@ public class NioHttpRequest extends RestRequest { return request.uri(); } - @Override - public boolean hasContent() { - return content.length() > 0; - } - @Override public BytesReference content() { return content; } - public FullHttpRequest getRequest() { + + @Override + public final Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } + } + + @Override + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new NioHttpRequest(requestWithoutHeader, sequence); + } + + @Override + public NioHttpResponse createResponse(RestStatus status, BytesReference content) { + return new NioHttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { return request; } + int sequence() { + return sequence; + } + /** * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications * and due to the underlying implementation, it performs case insensitive lookups of key to values. diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java index 4b634994b45..24de843dcc8 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java @@ -19,19 +19,100 @@ package org.elasticsearch.http.nio; -import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedMessage; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.rest.RestStatus; -public class NioHttpResponse extends HttpPipelinedMessage { +import java.util.Collections; +import java.util.EnumMap; +import java.util.Map; - private final FullHttpResponse response; +public class NioHttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage { - public NioHttpResponse(int sequence, FullHttpResponse response) { - super(sequence); - this.response = response; + private final int sequence; + private final NioHttpRequest request; + + NioHttpResponse(NioHttpRequest request, RestStatus status, BytesReference content) { + super(request.nettyRequest().protocolVersion(), getStatus(status), ByteBufUtils.toByteBuf(content)); + this.sequence = request.sequence(); + this.request = request; } - public FullHttpResponse getResponse() { - return response; + @Override + public void addHeader(String name, String value) { + headers().add(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers().contains(name); + } + + @Override + public int getSequence() { + return sequence; + } + + private static Map MAP; + + public NioHttpRequest getRequest() { + return request; + } + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 57aaebb16a1..5aac491a6ab 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -42,7 +42,6 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; -import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.HttpStats; import org.elasticsearch.http.nio.cors.NioCorsConfig; @@ -53,11 +52,11 @@ import org.elasticsearch.nio.EventHandler; import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannel; import org.elasticsearch.nio.NioGroup; +import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.nio.NioSelector; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -104,12 +103,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), (s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope); - private final BigArrays bigArrays; - private final ThreadPool threadPool; - private final NamedXContentRegistry xContentRegistry; - - private final HttpHandlingSettings httpHandlingSettings; - private final boolean tcpNoDelay; private final boolean tcpKeepAlive; private final boolean reuseAddress; @@ -124,16 +117,12 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); - this.bigArrays = bigArrays; - this.threadPool = threadPool; - this.xContentRegistry = xContentRegistry; + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);; this.corsConfig = buildCorsConfig(settings); this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); @@ -148,10 +137,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents); } - BigArrays getBigArrays() { - return bigArrays; - } - public Logger getLogger() { return logger; } @@ -335,17 +320,17 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { socketChannels.add(socketChannel); } - private class HttpChannelFactory extends ChannelFactory { + private class HttpChannelFactory extends ChannelFactory { private HttpChannelFactory() { super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); } @Override - public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { - NioSocketChannel nioChannel = new NioSocketChannel(channel); + public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + NioHttpChannel nioChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this, - httpHandlingSettings, xContentRegistry, corsConfig, threadPool.getThreadContext()); + handlingSettings, corsConfig); Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline, InboundChannelBuffer.allocatingInstance()); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java index 63585107037..98ae2d523ca 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java @@ -22,6 +22,7 @@ package org.elasticsearch.http.nio.cors; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.nio.NioHttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public class NioCorsHandler extends ChannelDuplexHandler { ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass(); + NioHttpResponse response = (NioHttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise); + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 6ad53521ee1..5bda7e1b83d 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -23,29 +23,31 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; -import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder; +import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -55,6 +57,9 @@ import java.nio.ByteBuffer; import java.util.List; import java.util.function.BiConsumer; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; @@ -64,7 +69,12 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_HEAD import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -72,7 +82,7 @@ import static org.mockito.Mockito.verify; public class HttpReadWriteHandlerTests extends ESTestCase { private HttpReadWriteHandler handler; - private NioSocketChannel nioSocketChannel; + private NioHttpChannel nioHttpChannel; private NioHttpServerTransport transport; private final RequestEncoder requestEncoder = new RequestEncoder(); @@ -96,15 +106,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase { SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings), SETTING_PIPELINING_MAX_EVENTS.getDefault(settings), SETTING_CORS_ENABLED.getDefault(settings)); - ThreadContext threadContext = new ThreadContext(settings); - nioSocketChannel = mock(NioSocketChannel.class); - handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY, - NioCorsConfigBuilder.forAnyOrigin().build(), threadContext); + nioHttpChannel = mock(NioHttpChannel.class); + handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build()); } public void testSuccessfulDecodeHttpRequest() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); int slicePoint = randomInt(buf.writerIndex() - 1); @@ -113,22 +121,21 @@ public class HttpReadWriteHandlerTests extends ESTestCase { ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); handler.consumeReads(toChannelBuffer(slicedBuf2)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest(); - assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion()); - assertEquals(httpRequest.method(), nettyHttpRequest.method()); + HttpRequest nioHttpRequest = requestCaptor.getValue(); + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + assertEquals(RestRequest.Method.GET, nioHttpRequest.method()); } public void testDecodeHttpRequestError() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); buf.setByte(0, ' '); @@ -137,15 +144,15 @@ public class HttpReadWriteHandlerTests extends ESTestCase { handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); - verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture()); assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); } public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); HttpUtil.setContentLength(httpRequest, 1025); HttpUtil.setKeepAlive(httpRequest, false); @@ -153,60 +160,176 @@ public class HttpReadWriteHandlerTests extends ESTestCase { handler.consumeReads(toChannelBuffer(buf)); - verify(transport, times(0)).dispatchBadRequest(any(), any(), any()); - verify(transport, times(0)).dispatchRequest(any(), any()); + verify(transport, times(0)).incomingRequestError(any(), any(), any()); + verify(transport, times(0)).incomingRequest(any(), any()); List flushOperations = handler.pollFlushOperations(); assertFalse(flushOperations.isEmpty()); FlushOperation flushOperation = flushOperations.get(0); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); flushOperation.getListener().accept(null, null); // Since we have keep-alive set to false, we should close the channel after the response has been // flushed - verify(nioSocketChannel).close(); + verify(nioHttpChannel).close(); } @SuppressWarnings("unchecked") public void testEncodeHttpResponse() throws IOException { prepareHandlerForResponse(handler); - FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); - NioHttpResponse pipelinedResponse = new NioHttpResponse(0, fullHttpResponse); + DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); + httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0"); SocketChannelContext context = mock(SocketChannelContext.class); - HttpWriteOperation writeOperation = new HttpWriteOperation(context, pipelinedResponse, mock(BiConsumer.class)); + HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class)); List flushOperations = handler.writeToBytes(writeOperation); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); assertEquals(HttpResponseStatus.OK, response.status()); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); } - private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException { - HttpMethod method = HttpMethod.GET; - HttpVersion version = HttpVersion.HTTP_1_1; + public void testCorsEnabledWithoutAllowOrigins() throws IOException { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() throws IOException { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() throws IOException { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + FullHttpResponse response = executeCorsRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() throws IOException { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() throws IOException { + final String originValue = NioCorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException { + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings); + HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig); + prepareHandlerForResponse(handler); + DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + NioHttpRequest nioHttpRequest = new NioHttpRequest(httpRequest, 0); + BytesArray content = new BytesArray("content"); + HttpResponse response = nioHttpRequest.createResponse(RestStatus.OK, content); + response.addHeader("Content-Length", Integer.toString(content.length())); + + SocketChannelContext context = mock(SocketChannelContext.class); + List flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {})); + + FlushOperation flushOperation = flushOperations.get(0); + return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + } + + + + private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException { + HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD; + HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1; String uri = "http://localhost:9090/" + randomAlphaOfLength(8); - HttpRequest request = new DefaultFullHttpRequest(version, method, uri); + io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri); ByteBuf buf = requestEncoder.encode(request); handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class); + verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest requestParsed = nioHttpRequest.getRequest(); - assertNotNull(requestParsed); - assertEquals(requestParsed.method(), method); - assertEquals(requestParsed.protocolVersion(), version); - assertEquals(requestParsed.uri(), uri); - return requestParsed; + NioHttpRequest nioHttpRequest = requestCaptor.getValue(); + assertNotNull(nioHttpRequest); + assertEquals(method.name(), nioHttpRequest.method().name()); + if (version == HttpVersion.HTTP_1_1) { + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + } else { + assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion()); + } + assertEquals(nioHttpRequest.uri(), uri); + return nioHttpRequest; } private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { @@ -226,11 +349,13 @@ public class HttpReadWriteHandlerTests extends ESTestCase { return buffer; } + private static final int MAX = 16 * 1024 * 1024; + private static class RequestEncoder { - private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder()); + private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder(), new HttpObjectAggregator(MAX)); - private ByteBuf encode(HttpRequest httpRequest) { + private ByteBuf encode(io.netty.handler.codec.http.HttpRequest httpRequest) { requestEncoder.writeOutbound(httpRequest); return requestEncoder.readOutbound(); } @@ -238,9 +363,9 @@ public class HttpReadWriteHandlerTests extends ESTestCase { private static class ResponseDecoder { - private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder()); + private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(MAX)); - private HttpResponse decode(ByteBuf response) { + private FullHttpResponse decode(ByteBuf response) { responseDecoder.writeInbound(response); return responseDecoder.readInbound(); } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java deleted file mode 100644 index 5fa0a7ae0a6..00000000000 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.nio; - -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.charset.StandardCharsets; -import java.util.function.BiConsumer; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class NioHttpChannelTests extends ESTestCase { - - private ThreadPool threadPool; - private MockBigArrays bigArrays; - private NioSocketChannel nioChannel; - private SocketChannelContext channelContext; - - @Before - public void setup() throws Exception { - nioChannel = mock(NioSocketChannel.class); - channelContext = mock(SocketChannelContext.class); - when(nioChannel.getContext()).thenReturn(channelContext); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(ByteBufUtils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = NioCorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - // send a response - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, corsConfig, - threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext).sendMessage(responseCaptor.capture(), any()); - Object nioResponse = responseCaptor.getValue(); - HttpResponse response = ((NioHttpResponse) nioResponse).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - - @SuppressWarnings("unchecked") - public void testReleaseInListener() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final NioHttpRequest request = new NioHttpRequest(registry, httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - if (randomBoolean()) { - listener.accept(null, null); - } else { - listener.accept(null, new ClosedChannelException()); - } - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - - - @SuppressWarnings("unchecked") - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - listener.accept(null, null); - if (close) { - verify(nioChannel, times(1)).close(); - } else { - verify(nioChannel, times(0)).close(); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, httpHandlingSettings, corsConfig, - threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext, atLeastOnce()).sendMessage(responseCaptor.capture(), any()); - return ((NioHttpResponse) responseCaptor.getValue()).getResponse(); - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = ByteBufUtils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } -} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java index 94d7db171a5..5f2784a3567 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.HttpRequest; import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -190,11 +189,11 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - NioHttpResponse resp = new NioHttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -231,10 +230,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -244,9 +243,10 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -258,7 +258,7 @@ public class NioHttpPipeliningHandlerTests extends ESTestCase { waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new NioHttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java index c43fc7d0723..48a5bf617a4 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java @@ -280,40 +280,6 @@ public class NioHttpServerTransportTests extends ESTestCase { assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (NioHttpServerTransport transport = - new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - // public void testReadTimeout() throws Exception { // final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { // diff --git a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java index c75754bde58..4fad4159f55 100644 --- a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java @@ -21,6 +21,7 @@ package org.elasticsearch.http; import com.carrotsearch.hppc.IntHashSet; import com.carrotsearch.hppc.IntSet; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.network.NetworkService; @@ -29,7 +30,9 @@ import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.PortsRange; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.threadpool.ThreadPool; @@ -48,11 +51,14 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PORT; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; -public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport { +public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport { + public final HttpHandlingSettings handlingSettings; protected final NetworkService networkService; + protected final BigArrays bigArrays; protected final ThreadPool threadPool; protected final Dispatcher dispatcher; + private final NamedXContentRegistry xContentRegistry; protected final String[] bindHosts; protected final String[] publishHosts; @@ -61,11 +67,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected volatile BoundTransportAddress boundAddress; - protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) { + protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { super(settings); this.networkService = networkService; + this.bigArrays = bigArrays; this.threadPool = threadPool; + this.xContentRegistry = xContentRegistry; this.dispatcher = dispatcher; + this.handlingSettings = HttpHandlingSettings.fromSettings(settings); // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); @@ -156,17 +166,94 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo return publishPort; } - public void dispatchRequest(final RestRequest request, final RestChannel channel) { + /** + * This method handles an incoming http request. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + */ + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { + handleIncomingRequest(httpRequest, httpChannel, null); + } + + /** + * This method handles an incoming http request that has encountered an error. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + * @param exception that was encountered + */ + public void incomingRequestError(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + handleIncomingRequest(httpRequest, httpChannel, exception); + } + + // Visible for testing + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchRequest(request, channel, threadContext); + if (badRequestCause != null) { + dispatcher.dispatchBadRequest(restRequest, channel, threadContext, badRequestCause); + } else { + dispatcher.dispatchRequest(restRequest, channel, threadContext); + } } } - public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchBadRequest(request, channel, threadContext, cause); + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + Exception badRequestCause = exception; + + /* + * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there + * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we + * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, + * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the + * underlying exception that caused us to treat the request as bad. + */ + final RestRequest restRequest; + { + RestRequest innerRestRequest; + try { + innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel); + } catch (final RestRequest.ContentTypeHeaderException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause); + } catch (final RestRequest.BadParameterException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + } + restRequest = innerRestRequest; + } + + /* + * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid + * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an + * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these + * parameter values. + */ + final RestChannel channel; + { + RestChannel innerChannel; + ThreadContext threadContext = threadPool.getThreadContext(); + try { + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext); + } catch (final IllegalArgumentException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext); + } + channel = innerChannel; + } + + dispatchRequest(restRequest, channel, badRequestCause); + } + + private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) { + HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); + try { + return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel); + } catch (final RestRequest.BadParameterException e) { + badRequestCause.addSuppressed(e); + return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } } } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java new file mode 100644 index 00000000000..f5924bb239e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -0,0 +1,172 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * The default rest channel for incoming requests. This class implements the basic logic for sending a rest + * response. It will set necessary headers nad ensure that bytes are released after the response is sent. + */ +public class DefaultRestChannel extends AbstractRestChannel implements RestChannel { + + static final String CLOSE = "close"; + static final String CONNECTION = "connection"; + static final String KEEP_ALIVE = "keep-alive"; + static final String CONTENT_TYPE = "content-type"; + static final String CONTENT_LENGTH = "content-length"; + static final String SET_COOKIE = "set-cookie"; + static final String X_OPAQUE_ID = "X-Opaque-Id"; + + private final HttpRequest httpRequest; + private final BigArrays bigArrays; + private final HttpHandlingSettings settings; + private final ThreadContext threadContext; + private final HttpChannel httpChannel; + + DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays, + HttpHandlingSettings settings, ThreadContext threadContext) { + super(request, settings.getDetailedErrorsEnabled()); + this.httpChannel = httpChannel; + this.httpRequest = httpRequest; + this.bigArrays = bigArrays; + this.settings = settings; + this.threadContext = threadContext; + } + + @Override + protected BytesStreamOutput newBytesOutput() { + return new ReleasableBytesStreamOutput(bigArrays); + } + + @Override + public void sendResponse(RestResponse restResponse) { + HttpResponse httpResponse; + if (RestRequest.Method.HEAD == request.method()) { + httpResponse = httpRequest.createResponse(restResponse.status(), BytesArray.EMPTY); + } else { + httpResponse = httpRequest.createResponse(restResponse.status(), restResponse.content()); + } + + // TODO: Ideally we should move the setting of Cors headers into :server + // NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); + + String opaque = request.header(X_OPAQUE_ID); + if (opaque != null) { + setHeaderField(httpResponse, X_OPAQUE_ID, opaque); + } + + // Add all custom headers + addCustomHeaders(httpResponse, restResponse.getHeaders()); + addCustomHeaders(httpResponse, threadContext.getResponseHeaders()); + + ArrayList toClose = new ArrayList<>(3); + + boolean success = false; + try { + // If our response doesn't specify a content-type header, set one + setHeaderField(httpResponse, CONTENT_TYPE, restResponse.contentType(), false); + // If our response has no content-length, calculate and set one + setHeaderField(httpResponse, CONTENT_LENGTH, String.valueOf(restResponse.content().length()), false); + + addCookies(httpResponse); + + BytesReference content = restResponse.content(); + if (content instanceof Releasable) { + toClose.add((Releasable) content); + } + BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); + if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { + toClose.add((Releasable) bytesStreamOutput); + } + + if (isCloseConnection()) { + toClose.add(httpChannel); + } + + ActionListener listener = ActionListener.wrap(() -> Releasables.close(toClose)); + httpChannel.sendResponse(httpResponse, listener); + success = true; + } finally { + if (success == false) { + Releasables.close(toClose); + } + } + + } + + private void setHeaderField(HttpResponse response, String headerField, String value) { + setHeaderField(response, headerField, value, true); + } + + private void setHeaderField(HttpResponse response, String headerField, String value, boolean override) { + if (override || !response.containsHeader(headerField)) { + response.addHeader(headerField, value); + } + } + + private void addCustomHeaders(HttpResponse response, Map> customHeaders) { + if (customHeaders != null) { + for (Map.Entry> headerEntry : customHeaders.entrySet()) { + for (String headerValue : headerEntry.getValue()) { + setHeaderField(response, headerEntry.getKey(), headerValue); + } + } + } + } + + private void addCookies(HttpResponse response) { + if (settings.isResetCookies()) { + List cookies = request.getHttpRequest().strictCookies(); + if (cookies.isEmpty() == false) { + for (String cookie : cookies) { + response.addHeader(SET_COOKIE, cookie); + } + } + } + } + + // Determine if the request connection should be closed on completion. + private boolean isCloseConnection() { + final boolean http10 = isHttp10(); + return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); + } + + // Determine if the request protocol version is HTTP 1.0 + private boolean isHttp10() { + return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + } +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpChannel.java b/server/src/main/java/org/elasticsearch/http/HttpChannel.java new file mode 100644 index 00000000000..baea3e0c3b3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpChannel.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.lease.Releasable; + +import java.net.InetSocketAddress; + +public interface HttpChannel extends Releasable { + + /** + * Sends a http response to the channel. The listener will be executed once the send process has been + * completed. + * + * @param response to send to channel + * @param listener to execute upon send completion + */ + void sendResponse(HttpResponse response, ActionListener listener); + + /** + * Returns the local address for this channel. + * + * @return the local address of this channel. + */ + InetSocketAddress getLocalAddress(); + + /** + * Returns the remote address for this channel. Can be null if channel does not have a remote address. + * + * @return the remote address of this channel. + */ + InetSocketAddress getRemoteAddress(); + + /** + * Closes the channel. This might be an asynchronous process. There is no guarantee that the channel + * will be closed when this method returns. + */ + void close(); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java index 7db8666e73a..ae1520cba60 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java @@ -18,20 +18,17 @@ */ package org.elasticsearch.http; -public class HttpPipelinedMessage implements Comparable { +public interface HttpPipelinedMessage extends Comparable { - private final int sequence; - - public HttpPipelinedMessage(int sequence) { - this.sequence = sequence; - } - - public int getSequence() { - return sequence; - } + /** + * Get the sequence number for this message. + * + * @return the sequence number + */ + int getSequence(); @Override - public int compareTo(HttpPipelinedMessage o) { - return Integer.compare(sequence, o.sequence); + default int compareTo(HttpPipelinedMessage o) { + return Integer.compare(getSequence(), o.getSequence()); } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java index df8bd7ee1eb..db3a2bae167 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java @@ -18,15 +18,21 @@ */ package org.elasticsearch.http; -public class HttpPipelinedRequest extends HttpPipelinedMessage { +public class HttpPipelinedRequest implements HttpPipelinedMessage { private final R request; + private final int sequence; HttpPipelinedRequest(int sequence, R request) { - super(sequence); + this.sequence = sequence; this.request = request; } + @Override + public int getSequence() { + return sequence; + } + public R getRequest() { return request; } diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java new file mode 100644 index 00000000000..496fec23312 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; + +import java.util.List; +import java.util.Map; + +/** + * A basic http request abstraction. Http modules needs to implement this interface to integrate with the + * server package's rest handling. + */ +public interface HttpRequest { + + enum HttpVersion { + HTTP_1_0, + HTTP_1_1 + } + + RestRequest.Method method(); + + /** + * The uri of the rest request, with the query string. + */ + String uri(); + + BytesReference content(); + + /** + * Get all of the headers and values associated with the headers. Modifications of this map are not supported. + */ + Map> getHeaders(); + + List strictCookies(); + + HttpVersion protocolVersion(); + + HttpRequest removeHeader(String header); + + /** + * Create an http response from this request and the supplied status and content. + */ + HttpResponse createResponse(RestStatus status, BytesReference content); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpResponse.java b/server/src/main/java/org/elasticsearch/http/HttpResponse.java new file mode 100644 index 00000000000..2d363f663c3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpResponse.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +/** + * A basic http response abstraction. Http modules must implement this interface as the server package rest + * handling needs to set http headers for a response. + */ +public interface HttpResponse { + + void addHeader(String name, String value); + + boolean containsHeader(String name); + +} diff --git a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java index d376b65ef2d..4e3d652ec5d 100644 --- a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java +++ b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java @@ -40,7 +40,7 @@ public abstract class AbstractRestChannel implements RestChannel { private static final Predicate EXCLUDE_FILTER = INCLUDE_FILTER.negate(); protected final RestRequest request; - protected final boolean detailedErrorsEnabled; + private final boolean detailedErrorsEnabled; private final String format; private final String filterPath; private final boolean pretty; diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index aae63f041fa..82fcf7178d1 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -272,8 +272,9 @@ public class RestController extends AbstractComponent implements HttpServerTrans */ private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) { if (restRequest.getXContentType() == null) { - if (restHandler.supportsContentStream() && restRequest.header("Content-Type") != null) { - final String lowercaseMediaType = restRequest.header("Content-Type").toLowerCase(Locale.ROOT); + String contentTypeHeader = restRequest.header("Content-Type"); + if (restHandler.supportsContentStream() && contentTypeHeader != null) { + final String lowercaseMediaType = contentTypeHeader.toLowerCase(Locale.ROOT); // we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/ if (lowercaseMediaType.equals("application/x-ndjson")) { restRequest.setXContentType(XContentType.JSON); diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index 65b4f9d1d36..813d6feb551 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -35,10 +35,11 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; import java.io.IOException; import java.io.InputStream; -import java.net.SocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -51,7 +52,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; -public abstract class RestRequest implements ToXContent.Params { +public class RestRequest implements ToXContent.Params { // tchar pattern as defined by RFC7230 section 3.2.6 private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+"); @@ -62,18 +63,47 @@ public abstract class RestRequest implements ToXContent.Params { private final String rawPath; private final Set consumedParams = new HashSet<>(); private final SetOnce xContentType = new SetOnce<>(); + private final HttpRequest httpRequest; + private final HttpChannel httpChannel; + + protected RestRequest(NamedXContentRegistry xContentRegistry, Map params, String path, + Map> headers, HttpRequest httpRequest, HttpChannel httpChannel) { + final XContentType xContentType; + try { + xContentType = parseContentType(headers.get("Content-Type")); + } catch (final IllegalArgumentException e) { + throw new ContentTypeHeaderException(e); + } + if (xContentType != null) { + this.xContentType.set(xContentType); + } + this.xContentRegistry = xContentRegistry; + this.httpRequest = httpRequest; + this.httpChannel = httpChannel; + this.params = params; + this.rawPath = path; + this.headers = Collections.unmodifiableMap(headers); + } + + protected RestRequest(RestRequest restRequest) { + this(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + } /** - * Creates a new REST request. + * Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be + * decoded * * @param xContentRegistry the content registry - * @param uri the raw URI that will be parsed into the path and the parameters - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws BadParameterException if the parameters can not be decoded * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map> headers) { - this(xContentRegistry, params(uri), path(uri), headers); + public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, HttpChannel httpChannel) { + Map params = params(httpRequest.uri()); + String path = path(httpRequest.uri()); + return new RestRequest(xContentRegistry, params, path, httpRequest.getHeaders(), httpRequest, httpChannel); } private static Map params(final String uri) { @@ -99,46 +129,34 @@ public abstract class RestRequest implements ToXContent.Params { } /** - * Creates a new REST request. In contrast to - * {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw - * a {@link BadParameterException}. + * Creates a new REST request. The path is not decoded so this constructor will not throw a + * {@link BadParameterException}. * * @param xContentRegistry the content registry - * @param params the request parameters - * @param path the raw path (which is not parsed) - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String path, - final Map> headers) { - final XContentType xContentType; - try { - xContentType = parseContentType(headers.get("Content-Type")); - } catch (final IllegalArgumentException e) { - throw new ContentTypeHeaderException(e); - } - if (xContentType != null) { - this.xContentType.set(xContentType); - } - this.xContentRegistry = xContentRegistry; - this.params = params; - this.rawPath = path; - this.headers = Collections.unmodifiableMap(headers); + public static RestRequest requestWithoutParameters(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, + HttpChannel httpChannel) { + Map params = Collections.emptyMap(); + return new RestRequest(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } public enum Method { GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT } - public abstract Method method(); + public Method method() { + return httpRequest.method(); + } /** * The uri of the rest request, with the query string. */ - public abstract String uri(); + public String uri() { + return httpRequest.uri(); + } /** * The non decoded, raw path provided. @@ -154,9 +172,13 @@ public abstract class RestRequest implements ToXContent.Params { return RestUtils.decodeComponent(rawPath()); } - public abstract boolean hasContent(); + public boolean hasContent() { + return content().length() > 0; + } - public abstract BytesReference content(); + public BytesReference content() { + return httpRequest.content(); + } /** * @return content of the request body or throw an exception if the body or content type is missing @@ -216,14 +238,12 @@ public abstract class RestRequest implements ToXContent.Params { this.xContentType.set(xContentType); } - @Nullable - public SocketAddress getRemoteAddress() { - return null; + public HttpChannel getHttpChannel() { + return httpChannel; } - @Nullable - public SocketAddress getLocalAddress() { - return null; + public HttpRequest getHttpRequest() { + return httpRequest; } public final boolean hasParam(String key) { diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 7e031f8d004..d0d6fa752d6 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -20,10 +20,10 @@ package org.elasticsearch.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,8 +31,7 @@ import java.util.Set; public abstract class RestResponse { - protected Map> customHeaders; - + private Map> customHeaders; /** * The response content type. @@ -81,10 +80,13 @@ public abstract class RestResponse { } /** - * Returns custom headers that have been added, or null if none have been set. + * Returns custom headers that have been added. This method should not be used to mutate headers. */ - @Nullable public Map> getHeaders() { - return customHeaders; + if (customHeaders == null) { + return Collections.emptyMap(); + } else { + return customHeaders; + } } } diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index ee74d98002f..a7629e5f48b 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -19,13 +19,27 @@ package org.elasticsearch.http; +import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkUtils; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import java.io.IOException; +import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import static java.net.InetAddress.getByName; @@ -36,6 +50,27 @@ import static org.hamcrest.Matchers.equalTo; public class AbstractHttpServerTransportTests extends ESTestCase { + private NetworkService networkService; + private ThreadPool threadPool; + private MockBigArrays bigArrays; + + @Before + public void setup() throws Exception { + networkService = new NetworkService(Collections.emptyList()); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() throws Exception { + if (threadPool != null) { + threadPool.shutdownNow(); + } + threadPool = null; + networkService = null; + bigArrays = null; + } + public void testHttpPublishPort() throws Exception { int boundPort = randomIntBetween(9000, 9100); int otherBoundPort = randomIntBetween(9200, 9300); @@ -71,6 +106,64 @@ public class AbstractHttpServerTransportTests extends ESTestCase { } } + public void testDispatchDoesNotModifyThreadContext() { + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + + @Override + public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("bar", "baz"); + } + + @Override + public void dispatchBadRequest(final RestRequest request, + final RestChannel channel, + final ThreadContext threadContext, + final Throwable cause) { + threadContext.putHeader("foo_bad", "bar"); + threadContext.putTransient("bar_bad", "baz"); + } + + }; + + try (AbstractHttpServerTransport transport = + new AbstractHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher) { + @Override + protected TransportAddress bindAddress(InetAddress hostAddress) { + return null; + } + + @Override + protected void doStart() { + + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } + + @Override + public HttpStats stats() { + return null; + } + }) { + + transport.dispatchRequest(null, null, null); + assertNull(threadPool.getThreadContext().getHeader("foo")); + assertNull(threadPool.getThreadContext().getTransient("bar")); + + transport.dispatchRequest(null, null, new Exception()); + assertNull(threadPool.getThreadContext().getHeader("foo_bad")); + assertNull(threadPool.getThreadContext().getTransient("bar_bad")); + } + } + private TransportAddress address(String host, int port) throws UnknownHostException { return new TransportAddress(getByName(host), port); } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java new file mode 100644 index 00000000000..bc499ed8a42 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -0,0 +1,444 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class DefaultRestChannelTests extends ESTestCase { + + private ThreadPool threadPool; + private MockBigArrays bigArrays; + private HttpChannel httpChannel; + + @Before + public void setup() { + httpChannel = mock(HttpChannel.class); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() { + if (threadPool != null) { + threadPool.shutdownNow(); + } + } + + public void testResponse() { + final TestResponse response = executeRequest(Settings.EMPTY, "request-host"); + assertThat(response.content(), equalTo(new TestRestResponse().content())); + } + + // TODO: Enable these Cors tests when the Cors logic lives in :server + +// public void testCorsEnabledWithoutAllowOrigins() { +// // Set up a HTTP transport with only the CORS enabled setting +// Settings settings = Settings.builder() +// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, "remote-host", "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); +// } +// +// public void testCorsEnabledWithAllowOrigins() { +// final String originValue = "remote-host"; +// // create a http transport with CORS enabled and allow origin configured +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testCorsAllowOriginWithSameHost() { +// String originValue = "remote-host"; +// String host = "remote-host"; +// // create a http transport with CORS enabled +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, host); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = "http://" + originValue; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue + ":5555"; +// host = host + ":5555"; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue.replace("http", "https"); +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testThatStringLiteralWorksOnMatch() { +// final String originValue = "remote-host"; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") +// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); +// } +// +// public void testThatAnyOriginWorks() { +// final String originValue = NioCorsHandler.ANY_ORIGIN; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); +// } + + public void testHeadersSet() { + Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + TestRestResponse resp = new TestRestResponse(); + final String customHeader = "custom-header"; + final String customHeaderValue = "xyz"; + resp.addHeader(customHeader, customHeaderValue); + channel.sendResponse(resp); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse httpResponse = responseCaptor.getValue(); + Map> headers = httpResponse.headers; + assertNull(headers.get("non-existent-header")); + assertEquals(customHeaderValue, headers.get(customHeader).get(0)); + assertEquals("abc", headers.get(DefaultRestChannel.X_OPAQUE_ID).get(0)); + assertEquals(Integer.toString(resp.content().length()), headers.get(DefaultRestChannel.CONTENT_LENGTH).get(0)); + assertEquals(resp.contentType(), headers.get(DefaultRestChannel.CONTENT_TYPE).get(0)); + } + + public void testCookiesSet() { + Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse nioResponse = responseCaptor.getValue(); + Map> headers = nioResponse.headers; + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie")); + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2")); + } + + @SuppressWarnings("unchecked") + public void testReleaseInListener() throws IOException { + final Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, + JsonXContent.contentBuilder().startObject().endObject()); + assertThat(response.content(), not(instanceOf(Releasable.class))); + + // ensure we have reserved bytes + if (randomBoolean()) { + BytesStreamOutput out = channel.bytesOutput(); + assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); + } else { + try (XContentBuilder builder = channel.newBuilder()) { + // do something builder + builder.startObject().endObject(); + } + } + + channel.sendResponse(response); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released + } + + @SuppressWarnings("unchecked") + public void testConnectionClose() throws Exception { + final Settings settings = Settings.builder().build(); + final HttpRequest httpRequest; + final boolean close = randomBoolean(); + if (randomBoolean()) { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + if (close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE)); + } + } else { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/"); + if (!close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE)); + } + } + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + if (close) { + verify(httpChannel, times(1)).close(); + } else { + verify(httpChannel, times(0)).close(); + } + } + + private TestResponse executeRequest(final Settings settings, final String host) { + return executeRequest(settings, null, host); + } + + private TestResponse executeRequest(final Settings settings, final String originValue, final String host) { + HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + // TODO: These exist for the Cors tests +// if (originValue != null) { +// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); +// } +// httpRequest.headers().add(HttpHeaderNames.HOST, host); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // get the response + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any()); + return responseCaptor.getValue(); + } + + private static class TestRequest implements HttpRequest { + + private final HttpVersion version; + private final RestRequest.Method method; + private final String uri; + private HashMap> headers = new HashMap<>(); + + private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { + + this.version = version; + this.method = method; + this.uri = uri; + } + + @Override + public RestRequest.Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return BytesArray.EMPTY; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Arrays.asList("cookie", "cookie2"); + } + + @Override + public HttpVersion protocolVersion() { + return version; + } + + @Override + public HttpRequest removeHeader(String header) { + throw new UnsupportedOperationException("Do not support removing header on test request."); + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + return new TestResponse(status, content); + } + } + + private static class TestResponse implements HttpResponse { + + private final RestStatus status; + private final BytesReference content; + private final Map> headers = new HashMap<>(); + + TestResponse(RestStatus status, BytesReference content) { + this.status = status; + this.content = content; + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return status; + } + + @Override + public void addHeader(String name, String value) { + if (headers.containsKey(name) == false) { + ArrayList values = new ArrayList<>(); + values.add(value); + headers.put(name, values); + } else { + headers.get(name).add(value); + } + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + } + + private static class TestRestResponse extends RestResponse { + + private final BytesReference content; + + TestRestResponse() { + content = new BytesArray("content".getBytes(StandardCharsets.UTF_8)); + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return RestStatus.OK; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java index a0e6f702030..a80c3b1bd42 100644 --- a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java +++ b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.transport.TransportAddress; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; @@ -165,28 +164,7 @@ public class BytesRestResponseTests extends ESTestCase { public void testResponseWhenPathContainsEncodingError() throws IOException { final String path = "%a"; - final RestRequest request = - new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) { - @Override - public Method method() { - return null; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return false; - } - - @Override - public BytesReference content() { - return null; - } - }; + final RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath(path).build(); final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath())); final RestChannel channel = new DetailedExceptionRestChannel(request); // if we try to decode the path, this will throw an IllegalArgumentException again diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index f36638a4390..a090cc40b68 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -110,21 +110,21 @@ public class RestControllerTests extends ESTestCase { RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build(); final RestController spyRestController = spy(restController); when(spyRestController.getAllHandlers(fakeRequest)) - .thenReturn(new Iterator() { - @Override - public boolean hasNext() { - return false; - } + .thenReturn(new Iterator() { + @Override + public boolean hasNext() { + return false; + } - @Override - public MethodHandlers next() { - return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { - assertEquals("true", threadContext.getHeader("header.1")); - assertEquals("true", threadContext.getHeader("header.2")); - assertNull(threadContext.getHeader("header.3")); - }, RestRequest.Method.GET); - } - }); + @Override + public MethodHandlers next() { + return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { + assertEquals("true", threadContext.getHeader("header.1")); + assertEquals("true", threadContext.getHeader("header.2")); + assertNull(threadContext.getHeader("header.3")); + }, RestRequest.Method.GET); + } + }); AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST); restController.dispatchRequest(fakeRequest, channel, threadContext); // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the @@ -136,7 +136,7 @@ public class RestControllerTests extends ESTestCase { public void testCanTripCircuitBreaker() throws Exception { RestController controller = new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService, - usageService); + usageService); // trip circuit breaker by default controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true)); controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false)); @@ -209,7 +209,7 @@ public class RestControllerTests extends ESTestCase { return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true); }; final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null, - circuitBreakerService, usageService); + circuitBreakerService, usageService); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(), null, null, Optional.of(handler)); assertTrue(wrapperCalled.get()); @@ -240,7 +240,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestAddsAndFreesBytesOnSuccess() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -252,7 +252,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestAddsAndFreesBytesOnError() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -265,7 +265,7 @@ public class RestControllerTests extends ESTestCase { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); // we will produce an error in the rest handler and one more when sending the error response - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -277,7 +277,7 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequestLimitsBytes() { int contentLength = BREAKER_LIMIT.bytesAsInt() + 1; String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -288,11 +288,11 @@ public class RestControllerTests extends ESTestCase { public void testDispatchRequiresContentTypeForRequestsWithContent() { String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt()); - TestRestRequest request = new TestRestRequest("/", content, null); + RestRequest request = testRestRequest("/", content, null); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE); restController = new RestController( Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(), - Collections.emptySet(), null, null, circuitBreakerService, usageService); + Collections.emptySet(), null, null, circuitBreakerService, usageService); restController.registerHandler(RestRequest.Method.GET, "/", (r, c, client) -> c.sendResponse( new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))); @@ -412,8 +412,8 @@ public class RestControllerTests extends ESTestCase { public void testNonStreamingXContentCausesErrorResponse() throws IOException { FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), - XContentType.YAML).withPath("/foo").build(); + .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), + XContentType.YAML).withPath("/foo").build(); AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.NOT_ACCEPTABLE); restController.registerHandler(RestRequest.Method.GET, "/foo", new RestHandler() { @Override @@ -457,10 +457,10 @@ public class RestControllerTests extends ESTestCase { final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); final AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST); restController.dispatchBadRequest( - fakeRestRequest, - channel, - new ThreadContext(Settings.EMPTY), - randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); + fakeRestRequest, + channel, + new ThreadContext(Settings.EMPTY), + randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); assertTrue(channel.getSendResponseCalled()); assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad request")); } @@ -495,7 +495,7 @@ public class RestControllerTests extends ESTestCase { @Override public BoundTransportAddress boundAddress() { TransportAddress transportAddress = buildNewFakeTransportAddress(); - return new BoundTransportAddress(new TransportAddress[] {transportAddress} ,transportAddress); + return new BoundTransportAddress(new TransportAddress[]{transportAddress}, transportAddress); } @Override @@ -547,35 +547,11 @@ public class RestControllerTests extends ESTestCase { } } - private static final class TestRestRequest extends RestRequest { - - private final BytesReference content; - - private TestRestRequest(String path, String content, XContentType xContentType) { - super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, xContentType == null ? - Collections.emptyMap() : Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType()))); - this.content = new BytesArray(content); - } - - @Override - public Method method() { - return Method.GET; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return true; - } - - @Override - public BytesReference content() { - return content; - } - + private static RestRequest testRestRequest(String path, String content, XContentType xContentType) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withPath(path); + builder.withContent(new BytesArray(content), xContentType); + return builder.build(); } } + diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index 1b4bbff7322..3ad9c61de3c 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.collect.MapBuilder; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; import java.io.IOException; import java.util.ArrayList; @@ -44,66 +45,66 @@ import static org.hamcrest.Matchers.instanceOf; public class RestRequestTests extends ESTestCase { public void testContentParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentParser()); + contentRestRequest("", emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", singletonMap("source", "{}")).contentParser()); + contentRestRequest("", singletonMap("source", "{}")).contentParser()); assertEquals("request body is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentParser().map()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap(), emptyMap()).contentParser()); + contentRestRequest("", emptyMap(), emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); } public void testApplyContentParser() throws IOException { - new ContentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); - new ContentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); AtomicReference source = new AtomicReference<>(); - new ContentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); + contentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); assertEquals(emptyMap(), source.get()); } public void testContentOrSourceParam() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParam()); + contentRestRequest("", emptyMap()).contentOrSourceParam()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("{\"foo\": \"stuff\"}"), - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .contentOrSourceParam().v2()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "stuff2").immutableMap()).contentOrSourceParam()); assertEquals("source and source_content_type parameters are required", e.getMessage()); } public void testHasContentOrSourceParam() throws IOException { - assertEquals(false, new ContentRestRequest("", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); + assertEquals(false, contentRestRequest("", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); } public void testContentOrSourceParamParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParamParser()); + contentRestRequest("", emptyMap()).contentOrSourceParamParser()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("", MapBuilder.newMapBuilder() + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map()); } public void testWithContentOrSourceParamParserOrNull() throws IOException { - new ContentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); - new ContentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> + contentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); + contentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); + contentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") + contentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") .put("source", "{}").immutableMap()) .withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); @@ -113,18 +114,18 @@ public class RestRequestTests extends ESTestCase { for (XContentType xContentType : XContentType.values()) { Map> map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaType())); - ContentRestRequest restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + RestRequest restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters())); - restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); } } public void testPlainTextSupport() { - ContentRestRequest restRequest = new ContentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), + RestRequest restRequest = contentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8")))); assertNull(restRequest.getXContentType()); @@ -136,7 +137,7 @@ public class RestRequestTests extends ESTestCase { RestRequest.ContentTypeHeaderException.class, () -> { final Map> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type)); - new ContentRestRequest("", Collections.emptyMap(), headers); + contentRestRequest("", Collections.emptyMap(), headers); }); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf(IllegalArgumentException.class)); @@ -144,7 +145,7 @@ public class RestRequestTests extends ESTestCase { } public void testNoContentTypeHeader() { - ContentRestRequest contentRestRequest = new ContentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); + RestRequest contentRestRequest = contentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); assertNull(contentRestRequest.getXContentType()); } @@ -152,7 +153,7 @@ public class RestRequestTests extends ESTestCase { List headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10))); final RestRequest.ContentTypeHeaderException e = expectThrows( RestRequest.ContentTypeHeaderException.class, - () -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); + () -> contentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf((IllegalArgumentException.class))); assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided")); @@ -160,52 +161,64 @@ public class RestRequestTests extends ESTestCase { public void testRequiredContent() { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).requiredContent()); + contentRestRequest("", emptyMap()).requiredContent()); assertEquals("request body is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).requiredContent()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).requiredContent()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .requiredContent()); assertEquals("request body is required", e.getMessage()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("test", null, Collections.emptyMap()).requiredContent()); + contentRestRequest("test", null, Collections.emptyMap()).requiredContent()); assertEquals("unknown content type", e.getMessage()); } + private static RestRequest contentRestRequest(String content, Map params) { + Map> headers = new HashMap<>(); + headers.put("Content-Type", Collections.singletonList("application/json")); + return contentRestRequest(content, params, headers); + } + + private static RestRequest contentRestRequest(String content, Map params, Map> headers) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(headers); + builder.withContent(new BytesArray(content), null); + builder.withParams(params); + return new ContentRestRequest(builder.build()); + } + private static final class ContentRestRequest extends RestRequest { - private final BytesArray content; - ContentRestRequest(String content, Map params) { - this(content, params, Collections.singletonMap("Content-Type", Collections.singletonList("application/json"))); - } + private final RestRequest restRequest; - ContentRestRequest(String content, Map params, Map> headers) { - super(NamedXContentRegistry.EMPTY, params, "not used by this test", headers); - this.content = new BytesArray(content); - } - - @Override - public boolean hasContent() { - return Strings.hasLength(content); - } - - @Override - public BytesReference content() { - return content; - } - - @Override - public String uri() { - throw new UnsupportedOperationException("Not used by this test"); + private ContentRestRequest(RestRequest restRequest) { + super(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + this.restRequest = restRequest; } @Override public Method method() { - throw new UnsupportedOperationException("Not used by this test"); + return restRequest.method(); + } + + @Override + public String uri() { + return restRequest.uri(); + } + + @Override + public boolean hasContent() { + return Strings.hasLength(content()); + } + + @Override + public BytesReference content() { + return restRequest.content(); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index d0403736400..4d4743156c7 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -19,12 +19,18 @@ package org.elasticsearch.test.rest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; -import java.net.SocketAddress; +import java.net.InetSocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -32,45 +38,115 @@ import java.util.Map; public class FakeRestRequest extends RestRequest { - private final BytesReference content; - private final Method method; - private final SocketAddress remoteAddress; - public FakeRestRequest() { - this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null); + this(NamedXContentRegistry.EMPTY, new FakeHttpRequest(Method.GET, "", BytesArray.EMPTY, new HashMap<>()), new HashMap<>(), + new FakeHttpChannel(null)); } - private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map> headers, - Map params, BytesReference content, Method method, String path, SocketAddress remoteAddress) { - super(xContentRegistry, params, path, headers); - this.content = content; - this.method = method; - this.remoteAddress = remoteAddress; - } - - @Override - public Method method() { - return method; - } - - @Override - public String uri() { - return rawPath(); + private FakeRestRequest(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, Map params, + HttpChannel httpChannel) { + super(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } @Override public boolean hasContent() { - return content != null; + return content() != null; } - @Override - public BytesReference content() { - return content; + private static class FakeHttpRequest implements HttpRequest { + + private final Method method; + private final String uri; + private final BytesReference content; + private final Map> headers; + + private FakeHttpRequest(Method method, String uri, BytesReference content, Map> headers) { + this.method = method; + this.uri = uri; + this.content = content; + this.headers = headers; + } + + @Override + public Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return content; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + return HttpVersion.HTTP_1_1; + } + + @Override + public HttpRequest removeHeader(String header) { + headers.remove(header); + return this; + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + Map headers = new HashMap<>(); + return new HttpResponse() { + @Override + public void addHeader(String name, String value) { + headers.put(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + }; + } } - @Override - public SocketAddress getRemoteAddress() { - return remoteAddress; + private static class FakeHttpChannel implements HttpChannel { + + private final InetSocketAddress remoteAddress; + + private FakeHttpChannel(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + + @Override + public void sendResponse(HttpResponse response, ActionListener listener) { + + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return remoteAddress; + } + + @Override + public void close() { + + } } public static class Builder { @@ -86,7 +162,7 @@ public class FakeRestRequest extends RestRequest { private Method method = Method.GET; - private SocketAddress address = null; + private InetSocketAddress address = null; public Builder(NamedXContentRegistry xContentRegistry) { this.xContentRegistry = xContentRegistry; @@ -120,15 +196,14 @@ public class FakeRestRequest extends RestRequest { return this; } - public Builder withRemoteAddress(SocketAddress address) { + public Builder withRemoteAddress(InetSocketAddress address) { this.address = address; return this; } public FakeRestRequest build() { - return new FakeRestRequest(xContentRegistry, headers, params, content, method, path, address); + FakeHttpRequest fakeHttpRequest = new FakeHttpRequest(method, path, content, headers); + return new FakeRestRequest(xContentRegistry, fakeHttpRequest, params, new FakeHttpChannel(address)); } - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java index aec5b3a04d2..71424ec507f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.security.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -17,7 +16,6 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.rest.RestRequest; import java.io.IOException; -import java.net.SocketAddress; import java.util.Map; import java.util.Set; @@ -33,37 +31,15 @@ public interface RestRequestFilter { default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException { Set fields = getFilteredFields(); if (restRequest.hasContent() && fields.isEmpty() == false) { - return new RestRequest(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders()) { + return new RestRequest(restRequest) { private BytesReference filteredBytes = null; - @Override - public Method method() { - return restRequest.method(); - } - - @Override - public String uri() { - return restRequest.uri(); - } - @Override public boolean hasContent() { return true; } - @Nullable - @Override - public SocketAddress getRemoteAddress() { - return restRequest.getRemoteAddress(); - } - - @Nullable - @Override - public SocketAddress getLocalAddress() { - return restRequest.getLocalAddress(); - } - @Override public BytesReference content() { if (filteredBytes == null) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java index 1976722d65f..1991c2685f2 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java @@ -69,7 +69,6 @@ import java.io.Closeable; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -829,10 +828,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } @@ -854,10 +852,9 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java index 3b9a42179a5..5706f79011a 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.security.transport.filter.SecurityIpFilterRule; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; import java.util.EnumSet; @@ -544,13 +543,8 @@ public class LoggingAuditTrail extends AbstractComponent implements AuditTrail, } private static String hostAttributes(RestRequest request) { - String formattedAddress; - final SocketAddress socketAddress = request.getRemoteAddress(); - if (socketAddress instanceof InetSocketAddress) { - formattedAddress = NetworkAddress.format(((InetSocketAddress) socketAddress).getAddress()); - } else { - formattedAddress = socketAddress.toString(); - } + final InetSocketAddress socketAddress = request.getHttpChannel().getRemoteAddress(); + String formattedAddress = NetworkAddress.format(socketAddress.getAddress()); return "origin_address=[" + formattedAddress + "]"; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java index dcee6535cf3..ed50a5cfe84 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java @@ -20,7 +20,7 @@ public class RemoteHostHeader { * then be copied to the subsequent action requests. */ public static void process(RestRequest request, ThreadContext threadContext) { - threadContext.putTransient(KEY, request.getRemoteAddress()); + threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress()); } /** diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index 0f4da8b847c..9109bb37e8c 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.rest; +import io.netty.channel.Channel; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -13,7 +14,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.logging.ESLoggerFactory; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.netty4.Netty4HttpRequest; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.netty4.Netty4HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -50,10 +52,11 @@ public class SecurityRestFilter implements RestHandler { if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) { // CORS - allow for preflight unauthenticated OPTIONS request if (extractClientCertificate) { - Netty4HttpRequest nettyHttpRequest = (Netty4HttpRequest) request; - SslHandler handler = nettyHttpRequest.getChannel().pipeline().get(SslHandler.class); + HttpChannel httpChannel = request.getHttpChannel(); + Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel(); + SslHandler handler = nettyChannel.pipeline().get(SslHandler.class); assert handler != null; - ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel()); + ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyChannel); } service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap( authentication -> { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java index 01916b91380..ac586c49457 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java @@ -104,7 +104,7 @@ public class SecurityNetty4HttpServerTransport extends Netty4HttpServerTransport private final class HttpSslChannelHandler extends HttpChannelHandler { HttpSslChannelHandler() { - super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext()); + super(SecurityNetty4HttpServerTransport.this, handlingSettings); } @Override diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java index 7878fdb9233..2e2a931f78f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.plugins.MetaDataUpgrader; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestRequest; @@ -914,7 +915,9 @@ public class IndexAuditTrailTests extends SecurityIntegTestCase { private RestRequest mockRestRequest() { RestRequest request = mock(RestRequest.class); - when(request.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); + HttpChannel httpChannel = mock(HttpChannel.class); + when(request.getHttpChannel()).thenReturn(httpChannel); + when(httpChannel.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); when(request.uri()).thenReturn("_uri"); return request; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java index 335673f1c0c..127784dcfc0 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java @@ -88,6 +88,6 @@ public class RestRequestFilterTests extends ESTestCase { new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON) .withRemoteAddress(address).build(); RestRequest filtered = filter.getFilteredRequest(restRequest); - assertEquals(address, filtered.getRemoteAddress()); + assertEquals(address, filtered.getHttpChannel().getRemoteAddress()); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java index 2857aee9b61..5db634c8d7b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -67,6 +68,7 @@ public class SecurityRestFilterTests extends ESTestCase { public void testProcess() throws Exception { RestRequest request = mock(RestRequest.class); + when(request.getHttpChannel()).thenReturn(mock(HttpChannel.class)); Authentication authentication = mock(Authentication.class); doAnswer((i) -> { ActionListener callback = From d6d0727aac0159a926b80ecce95913b928834267 Mon Sep 17 00:00:00 2001 From: Nik Everett Date: Thu, 14 Jun 2018 18:25:47 -0400 Subject: [PATCH 8/8] QA: Fix resolution of default distribution (#31351) If you run `./gradlew -p qa bwcTest -Dtests.distribution=zip` then we need to resolve older versions of the default distribution. Since those aren't available in maven central, we need add the elastic maven repo to the project. --- qa/build.gradle | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/qa/build.gradle b/qa/build.gradle index 709c309359e..0336b947d06 100644 --- a/qa/build.gradle +++ b/qa/build.gradle @@ -5,6 +5,20 @@ subprojects { Project subproj -> subproj.tasks.withType(RestIntegTestTask) { subproj.extensions.configure("${it.name}Cluster") { cluster -> cluster.distribution = System.getProperty('tests.distribution', 'oss-zip') + if (cluster.distribution == 'zip') { + /* + * Add Elastic's repositories so we can resolve older versions of the + * default distribution. Those aren't in maven central. + */ + repositories { + maven { + url "https://artifacts.elastic.co/maven" + } + maven { + url "https://snapshots.elastic.co/maven" + } + } + } } } }