Allow to set headers in HTTP response

This commit allows to set custom headers in HTTP responses (like
setting the WWW-Authenticate header for basic auth) by adding
RestRequest.addHeader() method.

Closes #2936
Closes #2540

To get the history right: This is based on PR #2723
This commit is contained in:
Alexander Reelsen 2013-05-03 09:22:34 +02:00
parent da5dff9ee4
commit 21fcc482eb
10 changed files with 366 additions and 63 deletions

View File

@ -36,6 +36,8 @@ import org.jboss.netty.channel.ChannelFutureListener;
import org.jboss.netty.handler.codec.http.*;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
@ -90,6 +92,16 @@ public class NettyHttpChannel implements HttpChannel {
resp.addHeader("X-Opaque-Id", opaque);
}
// Add all custom headers
Map<String, List<String>> customHeaders = response.getHeaders();
if (customHeaders != null) {
for (Map.Entry<String, List<String>> headerEntry : customHeaders.entrySet()) {
for (String headerValue : headerEntry.getValue()) {
resp.addHeader(headerEntry.getKey(), headerValue);
}
}
}
// Convert the response content to a ChannelBuffer.
ChannelFutureListener releaseContentListener = null;
ChannelBuffer buf;

View File

@ -19,11 +19,18 @@
package org.elasticsearch.rest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
/**
*
*/
public abstract class AbstractRestResponse implements RestResponse {
Map<String, List<String>> customHeaders;
@Override
public byte[] prefixContent() {
return null;
@ -53,4 +60,22 @@ public abstract class AbstractRestResponse implements RestResponse {
public int suffixContentOffset() {
return 0;
}
@Override
public void addHeader(String name, String value) {
if (customHeaders == null) {
customHeaders = new HashMap<String, List<String>>(2);
}
List<String> header = customHeaders.get(name);
if (header == null) {
header = new ArrayList<String>();
customHeaders.put(name, header);
}
header.add(value);
}
@Override
public Map<String, List<String>> getHeaders() {
return customHeaders;
}
}

View File

@ -20,6 +20,8 @@
package org.elasticsearch.rest;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
*
@ -59,4 +61,11 @@ public interface RestResponse {
int suffixContentOffset();
RestStatus status();
void addHeader(String name, String value);
/**
* @return The custom headers or null if none have been set
*/
Map<String, List<String>> getHeaders();
}

View File

@ -19,6 +19,12 @@
package org.elasticsearch.test.integration.nodesinfo;
import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse;
import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse;
import org.elasticsearch.action.admin.cluster.node.info.PluginInfo;
@ -35,7 +41,10 @@ import org.testng.annotations.Test;
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.Collections;
import java.util.List;
import static com.google.common.base.Predicates.*;
import static org.elasticsearch.client.Requests.clusterHealthRequest;
import static org.elasticsearch.client.Requests.nodesInfoRequest;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
@ -72,29 +81,29 @@ public class SimpleNodesInfoTests extends AbstractNodesTests {
logger.info("--> started nodes: " + server1NodeId + " and " + server2NodeId);
NodesInfoResponse response = client("server1").admin().cluster().prepareNodesInfo().execute().actionGet();
assertThat(response.getNodes().length, equalTo(2));
assertThat(response.getNodes().length, is(2));
assertThat(response.getNodesMap().get(server1NodeId), notNullValue());
assertThat(response.getNodesMap().get(server2NodeId), notNullValue());
response = client("server2").admin().cluster().nodesInfo(nodesInfoRequest()).actionGet();
assertThat(response.getNodes().length, equalTo(2));
assertThat(response.getNodes().length, is(2));
assertThat(response.getNodesMap().get(server1NodeId), notNullValue());
assertThat(response.getNodesMap().get(server2NodeId), notNullValue());
response = client("server1").admin().cluster().nodesInfo(nodesInfoRequest(server1NodeId)).actionGet();
assertThat(response.getNodes().length, equalTo(1));
assertThat(response.getNodes().length, is(1));
assertThat(response.getNodesMap().get(server1NodeId), notNullValue());
response = client("server2").admin().cluster().nodesInfo(nodesInfoRequest(server1NodeId)).actionGet();
assertThat(response.getNodes().length, equalTo(1));
assertThat(response.getNodes().length, is(1));
assertThat(response.getNodesMap().get(server1NodeId), notNullValue());
response = client("server1").admin().cluster().nodesInfo(nodesInfoRequest(server2NodeId)).actionGet();
assertThat(response.getNodes().length, equalTo(1));
assertThat(response.getNodes().length, is(1));
assertThat(response.getNodesMap().get(server2NodeId), notNullValue());
response = client("server2").admin().cluster().nodesInfo(nodesInfoRequest(server2NodeId)).actionGet();
assertThat(response.getNodes().length, equalTo(1));
assertThat(response.getNodes().length, is(1));
assertThat(response.getNodesMap().get(server2NodeId), notNullValue());
}
@ -127,68 +136,64 @@ public class SimpleNodesInfoTests extends AbstractNodesTests {
NodesInfoResponse response = client("node1").admin().cluster().prepareNodesInfo().setPlugin(true).execute().actionGet();
logger.info("--> full json answer, status " + response.toString());
checkPlugin(response, server1NodeId, 0, 0);
checkPlugin(response, server2NodeId, 1, 0);
checkPlugin(response, server3NodeId, 0, 1);
checkPlugin(response, server4NodeId, 1, 2); // Note that we have now 2 JVM plugins as we have already loaded one with node3
assertNodeContainsPlugins(response, server1NodeId, Collections.EMPTY_LIST, Collections.EMPTY_LIST,
Collections.EMPTY_LIST, Collections.EMPTY_LIST);
assertNodeContainsPlugins(response, server2NodeId, Collections.EMPTY_LIST, Collections.EMPTY_LIST,
Lists.newArrayList(Fields.SITE_PLUGIN),
Lists.newArrayList(Fields.SITE_PLUGIN_DESCRIPTION));
assertNodeContainsPlugins(response, server3NodeId, Lists.newArrayList(TestPlugin.Fields.NAME),
Lists.newArrayList(TestPlugin.Fields.DESCRIPTION),
Collections.EMPTY_LIST, Collections.EMPTY_LIST);
// Note that we have now 2 JVM plugins as we have already loaded one with node3
assertNodeContainsPlugins(response, server4NodeId,
Lists.newArrayList(TestPlugin.Fields.NAME, TestNoVersionPlugin.Fields.NAME),
Lists.newArrayList(TestPlugin.Fields.DESCRIPTION, TestNoVersionPlugin.Fields.DESCRIPTION),
Lists.newArrayList(Fields.SITE_PLUGIN, TestNoVersionPlugin.Fields.NAME),
Lists.newArrayList(Fields.SITE_PLUGIN_NO_DESCRIPTION, TestNoVersionPlugin.Fields.DESCRIPTION));
}
/**
* We check infos
* @param response Response
* @param nodeId NodeId we want to check
* @param expectedSitePlugins Number of site plugins expected
* @param expectedJvmPlugins Number of jvm plugins expected
*/
private void checkPlugin(NodesInfoResponse response, String nodeId,
int expectedSitePlugins, int expectedJvmPlugins) {
private void assertNodeContainsPlugins(NodesInfoResponse response, String nodeId,
List<String> expectedJvmPluginNames,
List<String> expectedJvmPluginDescriptions,
List<String> expectedSitePluginNames,
List<String> expectedSitePluginDescriptions) {
assertThat(response.getNodesMap().get(nodeId), notNullValue());
PluginsInfo plugins = response.getNodesMap().get(nodeId).getPlugins();
assertThat(plugins, notNullValue());
int num_site_plugins = 0;
int num_jvm_plugins = 0;
for (PluginInfo pluginInfo : plugins.getInfos()) {
// It should be a site or a jvm plugin
assertThat(pluginInfo.isJvm() || pluginInfo.isSite(), is(true));
if (pluginInfo.isSite() && !pluginInfo.isJvm()) {
// Let's do some tests for site plugins
assertThat(pluginInfo.getName(), isOneOf(Fields.SITE_PLUGIN,
TestNoVersionPlugin.Fields.NAME));
assertThat(pluginInfo.getDescription(),
isOneOf(Fields.SITE_PLUGIN_DESCRIPTION,
Fields.SITE_PLUGIN_NO_DESCRIPTION,
Fields.JVM_PLUGIN_NO_DESCRIPTION));
assertThat(pluginInfo.getUrl(), notNullValue());
num_site_plugins++;
}
if (pluginInfo.isJvm() && !pluginInfo.isSite()) {
// Let's do some tests for site plugins
assertThat(pluginInfo.getName(),
isOneOf(TestPlugin.Fields.NAME, TestNoVersionPlugin.Fields.NAME));
assertThat(pluginInfo.getDescription(),
isOneOf(TestPlugin.Fields.DESCRIPTION, TestNoVersionPlugin.Fields.DESCRIPTION));
assertThat(pluginInfo.getUrl(), nullValue());
num_jvm_plugins++;
}
// On node4, test-no-version-plugin has an embedded _site structure
if (pluginInfo.isJvm() && pluginInfo.isSite()) {
assertThat(pluginInfo.getName(),
is(TestNoVersionPlugin.Fields.NAME));
assertThat(pluginInfo.getDescription(),
is(TestNoVersionPlugin.Fields.DESCRIPTION));
assertThat(pluginInfo.getUrl(), notNullValue());
num_jvm_plugins++;
}
List<String> pluginNames = FluentIterable.from(plugins.getInfos()).filter(jvmPluginPredicate).transform(nameFunction).toList();
for (String expectedJvmPluginName : expectedJvmPluginNames) {
assertThat(pluginNames, hasItem(expectedJvmPluginName));
}
assertThat(num_site_plugins, is(expectedSitePlugins));
assertThat(num_jvm_plugins, is(expectedJvmPlugins));
List<String> pluginDescriptions = FluentIterable.from(plugins.getInfos()).filter(jvmPluginPredicate).transform(descriptionFunction).toList();
for (String expectedJvmPluginDescription : expectedJvmPluginDescriptions) {
assertThat(pluginDescriptions, hasItem(expectedJvmPluginDescription));
}
FluentIterable<String> jvmUrls = FluentIterable.from(plugins.getInfos())
.filter(and(jvmPluginPredicate, Predicates.not(sitePluginPredicate)))
.filter(isNull())
.transform(urlFunction);
assertThat(Iterables.size(jvmUrls), is(0));
List<String> sitePluginNames = FluentIterable.from(plugins.getInfos()).filter(sitePluginPredicate).transform(nameFunction).toList();
for (String expectedSitePluginName : expectedSitePluginNames) {
assertThat(sitePluginNames, hasItem(expectedSitePluginName));
}
List<String> sitePluginDescriptions = FluentIterable.from(plugins.getInfos()).filter(sitePluginPredicate).transform(descriptionFunction).toList();
for (String sitePluginDescription : expectedSitePluginDescriptions) {
assertThat(sitePluginDescriptions, hasItem(sitePluginDescription));
}
List<String> sitePluginUrls = FluentIterable.from(plugins.getInfos()).filter(sitePluginPredicate).transform(urlFunction).toList();
assertThat(sitePluginUrls, not(contains(nullValue())));
}
private String startNodeWithPlugins(String name) throws URISyntaxException {
@ -209,4 +214,34 @@ public class SimpleNodesInfoTests extends AbstractNodesTests {
return serverNodeId;
}
private Predicate<PluginInfo> jvmPluginPredicate = new Predicate<PluginInfo>() {
public boolean apply(PluginInfo pluginInfo) {
return pluginInfo.isJvm();
}
};
private Predicate<PluginInfo> sitePluginPredicate = new Predicate<PluginInfo>() {
public boolean apply(PluginInfo pluginInfo) {
return pluginInfo.isSite();
}
};
private Function<PluginInfo, String> nameFunction = new Function<PluginInfo, String>() {
public String apply(PluginInfo pluginInfo) {
return pluginInfo.getName();
}
};
private Function<PluginInfo, String> descriptionFunction = new Function<PluginInfo, String>() {
public String apply(PluginInfo pluginInfo) {
return pluginInfo.getDescription();
}
};
private Function<PluginInfo, String> urlFunction = new Function<PluginInfo, String>() {
public String apply(PluginInfo pluginInfo) {
return pluginInfo.getUrl();
}
};
}

View File

@ -0,0 +1,84 @@
/*
* Licensed to Elastic Search and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. ElasticSearch licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.test.integration.plugin;
import static org.elasticsearch.client.Requests.clusterHealthRequest;
import static org.elasticsearch.common.settings.ImmutableSettings.settingsBuilder;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import com.beust.jcommander.internal.Maps;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.node.internal.InternalNode;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.integration.AbstractNodesTests;
import org.elasticsearch.test.integration.rest.helper.HttpClient;
import org.elasticsearch.test.integration.rest.helper.HttpClientResponse;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import java.io.File;
import java.net.URL;
import java.util.Map;
/**
* Test a rest action that sets special response headers
*/
public class ResponseHeaderPluginTests extends AbstractNodesTests {
public static final String NODE_ID = "TEST";
@BeforeMethod
public void startNode() throws Exception {
URL resource = ResponseHeaderPluginTests.class.getResource("/org/elasticsearch/test/integration/responseheader/");
ImmutableSettings.Builder settings = settingsBuilder();
if (resource != null) {
settings.put("path.plugins", new File(resource.toURI()).getAbsolutePath());
}
startNode(NODE_ID, settings);
client(NODE_ID).admin().cluster().health(clusterHealthRequest().waitForGreenStatus()).actionGet();
}
@AfterMethod
public void closeNodes() {
closeAllNodes();
}
@Test
public void testThatSettingHeadersWorks() throws Exception {
HttpClientResponse response = httpClient().request("/_protected");
assertThat(response.errorCode(), equalTo(RestStatus.UNAUTHORIZED.getStatus()));
assertThat(response.getHeader("Secret"), equalTo("required"));
Map<String, String> headers = Maps.newHashMap();
headers.put("Secret", "password");
HttpClientResponse authResponse = httpClient().request("GET", "_protected", headers);
assertThat(authResponse.errorCode(), equalTo(RestStatus.OK.getStatus()));
assertThat(authResponse.getHeader("Secret"), equalTo("granted"));
}
private HttpClient httpClient() {
HttpServerTransport httpServerTransport = ((InternalNode) node(NODE_ID)).injector().getInstance(HttpServerTransport.class);
return new HttpClient(httpServerTransport.boundAddress().publishAddress());
}
}

View File

@ -0,0 +1,40 @@
/*
* Licensed to ElasticSearch and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. ElasticSearch licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.test.integration.plugin.responseheader;
import org.elasticsearch.plugins.AbstractPlugin;
import org.elasticsearch.rest.RestModule;
public class TestResponseHeaderPlugin extends AbstractPlugin {
@Override
public String name() {
return "test-plugin-custom-header";
}
@Override
public String description() {
return "test-plugin-custom-header-desc";
}
public void onModule(RestModule restModule) {
restModule.addRestAction(TestResponseHeaderRestAction.class);
}
}

View File

@ -0,0 +1,46 @@
/*
* Licensed to ElasticSearch and Shay Banon under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. ElasticSearch licenses this
* file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.test.integration.plugin.responseheader;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.*;
public class TestResponseHeaderRestAction extends BaseRestHandler {
@Inject
public TestResponseHeaderRestAction(Settings settings, Client client, RestController controller) {
super(settings, client);
controller.registerHandler(RestRequest.Method.GET, "/_protected", this);
}
@Override
public void handleRequest(RestRequest request, RestChannel channel) {
if ("password".equals(request.header("Secret"))) {
RestResponse response = new StringRestResponse(RestStatus.OK, "Access granted");
response.addHeader("Secret", "granted");
channel.sendResponse(response);
} else {
RestResponse response = new StringRestResponse(RestStatus.UNAUTHORIZED, "Access denied");
response.addHeader("Secret", "required");
channel.sendResponse(response);
}
}
}

View File

@ -30,6 +30,8 @@ import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.Map;
import java.util.List;
public class HttpClient {
@ -61,6 +63,10 @@ public class HttpClient {
}
public HttpClientResponse request(String method, String path) {
return request(method, path, null);
}
public HttpClientResponse request(String method, String path, Map<String, String> headers) {
URL url;
try {
url = new URL(baseUrl, path);
@ -72,14 +78,21 @@ public class HttpClient {
try {
urlConnection = (HttpURLConnection) url.openConnection();
urlConnection.setRequestMethod(method);
if (headers != null) {
for (Map.Entry<String,String> headerEntry : headers.entrySet()) {
urlConnection.setRequestProperty(headerEntry.getKey(), headerEntry.getValue());
}
}
urlConnection.connect();
} catch (IOException e) {
throw new ElasticSearchException("", e);
}
int errorCode = -1;
Map<String, List<String>> respHeaders = null;
try {
errorCode = urlConnection.getResponseCode();
respHeaders = urlConnection.getHeaderFields();
InputStream inputStream = urlConnection.getInputStream();
String body = null;
try {
@ -87,7 +100,7 @@ public class HttpClient {
} catch (IOException e1) {
throw new ElasticSearchException("problem reading error stream", e1);
}
return new HttpClientResponse(body, errorCode, null);
return new HttpClientResponse(body, errorCode, respHeaders, null);
} catch (IOException e) {
InputStream errStream = urlConnection.getErrorStream();
String body = null;
@ -96,8 +109,7 @@ public class HttpClient {
} catch (IOException e1) {
throw new ElasticSearchException("problem reading error stream", e1);
}
return new HttpClientResponse(body, errorCode, e);
return new HttpClientResponse(body, errorCode, respHeaders, e);
} finally {
urlConnection.disconnect();
}

View File

@ -18,15 +18,20 @@
*/
package org.elasticsearch.test.integration.rest.helper;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
public class HttpClientResponse {
private final String response;
private final int errorCode;
private Map<String, List<String>> headers;
private final Throwable e;
public HttpClientResponse(String response, int errorCode, Throwable e) {
public HttpClientResponse(String response, int errorCode, Map<String, List<String>> headers, Throwable e) {
this.response = response;
this.errorCode = errorCode;
this.headers = headers;
this.e = e;
}
@ -41,4 +46,19 @@ public class HttpClientResponse {
public Throwable cause() {
return e;
}
public Map<String, List<String>> getHeaders() {
return headers;
}
public String getHeader(String name) {
if (headers == null) {
return null;
}
List<String> vals = headers.get(name);
if (vals == null || vals.size() == 0) {
return null;
}
return vals.iterator().next();
}
}

View File

@ -0,0 +1,20 @@
################################################################
# Licensed to ElasticSearch and Shay Banon under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. ElasticSearch licenses this
# file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
################################################################
plugin=org.elasticsearch.test.integration.plugin.responseheader.TestResponseHeaderPlugin