Refactor PathTrie and RestController to use a single trie for all methods (#25459)

* Refactor PathTrie and RestController to use a single trie for all methods

This changes `PathTrie` and `RestController` to use a single `PathTrie` for all
endpoints, it also allows retrieving the endpoints' supported HTTP methods more
easily.

This is a spin-off and prerequisite of #24437

* Use EnumSet instead of multiple if conditions

* Make MethodHandlers package-private and final

* Remove duplicate registerHandler method

* Remove public modifier
This commit is contained in:
Lee Hinman 2017-07-05 17:28:10 -06:00 committed by GitHub
parent 6e5cc424a8
commit 30b5ca7ab7
7 changed files with 542 additions and 120 deletions

View File

@ -19,14 +19,49 @@
package org.elasticsearch.common.path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;
public class PathTrie<T> {
enum TrieMatchingMode {
/*
* Retrieve only explicitly mapped nodes, no wildcards are
* matched.
*/
EXPLICIT_NODES_ONLY,
/*
* Retrieve only explicitly mapped nodes, with wildcards
* allowed as root nodes.
*/
WILDCARD_ROOT_NODES_ALLOWED,
/*
* Retrieve only explicitly mapped nodes, with wildcards
* allowed as leaf nodes.
*/
WILDCARD_LEAF_NODES_ALLOWED,
/*
* Retrieve both explicitly mapped and wildcard nodes.
*/
WILDCARD_NODES_ALLOWED
}
static EnumSet<TrieMatchingMode> EXPLICIT_OR_ROOT_WILDCARD =
EnumSet.of(TrieMatchingMode.EXPLICIT_NODES_ONLY, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED);
public interface Decoder {
String decode(String value);
}
@ -107,15 +142,15 @@ public class PathTrie<T> {
if (isNamedWildcard(token)) {
node.updateKeyWithNamedWildcard(token);
}
// in case the target(last) node already exist but without a value
// than the value should be updated.
/*
* If the target node already exists, but is without a value,
* then the value should be updated.
*/
if (index == (path.length - 1)) {
if (node.value != null) {
throw new IllegalArgumentException("Path [" + String.join("/", path)+ "] already has a value ["
+ node.value + "]");
}
if (node.value == null) {
} else {
node.value = value;
}
}
@ -124,6 +159,40 @@ public class PathTrie<T> {
node.insert(path, index + 1, value);
}
public synchronized void insertOrUpdate(String[] path, int index, T value, BiFunction<T, T, T> updater) {
if (index >= path.length)
return;
String token = path[index];
String key = token;
if (isNamedWildcard(token)) {
key = wildcard;
}
TrieNode node = children.get(key);
if (node == null) {
T nodeValue = index == path.length - 1 ? value : null;
node = new TrieNode(token, nodeValue, wildcard);
addInnerChild(key, node);
} else {
if (isNamedWildcard(token)) {
node.updateKeyWithNamedWildcard(token);
}
/*
* If the target node already exists, but is without a value,
* then the value should be updated.
*/
if (index == (path.length - 1)) {
if (node.value != null) {
node.value = updater.apply(node.value, value);
} else {
node.value = value;
}
}
}
node.insertOrUpdate(path, index + 1, value, updater);
}
private boolean isNamedWildcard(String key) {
return key.indexOf('{') != -1 && key.indexOf('}') != -1;
}
@ -136,23 +205,57 @@ public class PathTrie<T> {
return namedWildcard != null;
}
public T retrieve(String[] path, int index, Map<String, String> params) {
public T retrieve(String[] path, int index, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
if (index >= path.length)
return null;
String token = path[index];
TrieNode node = children.get(token);
boolean usedWildcard;
if (node == null) {
node = children.get(wildcard);
if (node == null) {
if (trieMatchingMode == TrieMatchingMode.WILDCARD_NODES_ALLOWED) {
node = children.get(wildcard);
if (node == null) {
return null;
}
usedWildcard = true;
} else if (trieMatchingMode == TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED && index == 1) {
/*
* Allow root node wildcard matches.
*/
node = children.get(wildcard);
if (node == null) {
return null;
}
usedWildcard = true;
} else if (trieMatchingMode == TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED && index + 1 == path.length) {
/*
* Allow leaf node wildcard matches.
*/
node = children.get(wildcard);
if (node == null) {
return null;
}
usedWildcard = true;
} else {
return null;
}
usedWildcard = true;
} else {
// If we are at the end of the path, the current node does not have a value but there
// is a child wildcard node, use the child wildcard node
if (index + 1 == path.length && node.value == null && children.get(wildcard) != null) {
if (index + 1 == path.length && node.value == null && children.get(wildcard) != null
&& EXPLICIT_OR_ROOT_WILDCARD.contains(trieMatchingMode) == false) {
/*
* If we are at the end of the path, the current node does not have a value but
* there is a child wildcard node, use the child wildcard node.
*/
node = children.get(wildcard);
usedWildcard = true;
} else if (index == 1 && node.value == null && children.get(wildcard) != null
&& trieMatchingMode == TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED) {
/*
* If we are at the root, and root wildcards are allowed, use the child wildcard
* node.
*/
node = children.get(wildcard);
usedWildcard = true;
} else {
@ -166,16 +269,16 @@ public class PathTrie<T> {
return node.value;
}
T res = node.retrieve(path, index + 1, params);
if (res == null && !usedWildcard) {
T nodeValue = node.retrieve(path, index + 1, params, trieMatchingMode);
if (nodeValue == null && !usedWildcard && trieMatchingMode != TrieMatchingMode.EXPLICIT_NODES_ONLY) {
node = children.get(wildcard);
if (node != null) {
put(params, node, token);
res = node.retrieve(path, index + 1, params);
nodeValue = node.retrieve(path, index + 1, params, trieMatchingMode);
}
}
return res;
return nodeValue;
}
private void put(Map<String, String> params, TrieNode node, String value) {
@ -200,18 +303,47 @@ public class PathTrie<T> {
return;
}
int index = 0;
// supports initial delimiter.
// Supports initial delimiter.
if (strings.length > 0 && strings[0].isEmpty()) {
index = 1;
}
root.insert(strings, index, value);
}
/**
* Insert a value for the given path. If the path already exists, replace the value with:
* <pre>
* value = updater.apply(oldValue, newValue);
* </pre>
* allowing the value to be updated if desired.
*/
public void insertOrUpdate(String path, T value, BiFunction<T, T, T> updater) {
String[] strings = path.split(SEPARATOR);
if (strings.length == 0) {
if (rootValue != null) {
rootValue = updater.apply(rootValue, value);
} else {
rootValue = value;
}
return;
}
int index = 0;
// Supports initial delimiter.
if (strings.length > 0 && strings[0].isEmpty()) {
index = 1;
}
root.insertOrUpdate(strings, index, value, updater);
}
public T retrieve(String path) {
return retrieve(path, null);
return retrieve(path, null, TrieMatchingMode.WILDCARD_NODES_ALLOWED);
}
public T retrieve(String path, Map<String, String> params) {
return retrieve(path, params, TrieMatchingMode.WILDCARD_NODES_ALLOWED);
}
public T retrieve(String path, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
if (path.length() == 0) {
return rootValue;
}
@ -220,10 +352,56 @@ public class PathTrie<T> {
return rootValue;
}
int index = 0;
// supports initial delimiter.
// Supports initial delimiter.
if (strings.length > 0 && strings[0].isEmpty()) {
index = 1;
}
return root.retrieve(strings, index, params);
return root.retrieve(strings, index, params, trieMatchingMode);
}
/**
* Returns an iterator of the objects stored in the {@code PathTrie}, using
* all possible {@code TrieMatchingMode} modes. The {@code paramSupplier}
* is called between each invocation of {@code next()} to supply a new map
* of parameters.
*/
public Iterator<T> retrieveAll(String path, Supplier<Map<String, String>> paramSupplier) {
return new PathTrieIterator<>(this, path, paramSupplier);
}
class PathTrieIterator<T> implements Iterator<T> {
private final List<TrieMatchingMode> modes;
private final Supplier<Map<String, String>> paramSupplier;
private final PathTrie<T> trie;
private final String path;
PathTrieIterator(PathTrie trie, String path, Supplier<Map<String, String>> paramSupplier) {
this.path = path;
this.trie = trie;
this.paramSupplier = paramSupplier;
this.modes = new ArrayList<>(Arrays.asList(TrieMatchingMode.EXPLICIT_NODES_ONLY,
TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED,
TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED,
TrieMatchingMode.WILDCARD_NODES_ALLOWED));
assert TrieMatchingMode.values().length == 4 : "missing trie matching mode";
}
@Override
public boolean hasNext() {
return modes.isEmpty() == false;
}
@Override
public T next() {
if (modes.isEmpty()) {
throw new NoSuchElementException("called next() without validating hasNext()! no more modes available");
}
TrieMatchingMode mode = modes.remove(0);
Map<String, String> params = paramSupplier.get();
return trie.retrieve(path, params, mode);
}
}
}

View File

@ -0,0 +1,80 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
/**
* Encapsulate multiple handlers for the same path, allowing different handlers for different HTTP verbs.
*/
final class MethodHandlers {
private final String path;
private final Map<RestRequest.Method, RestHandler> methodHandlers;
MethodHandlers(String path, RestHandler handler, RestRequest.Method... methods) {
this.path = path;
this.methodHandlers = new HashMap<>(methods.length);
for (RestRequest.Method method : methods) {
methodHandlers.put(method, handler);
}
}
/**
* Add an additional method and handler for an existing path. Note that {@code MethodHandlers}
* does not allow replacing the handler for an already existing method.
*/
public MethodHandlers addMethod(RestRequest.Method method, RestHandler handler) {
RestHandler existing = methodHandlers.putIfAbsent(method, handler);
if (existing != null) {
throw new IllegalArgumentException("Cannot replace existing handler for [" + path + "] for method: " + method);
}
return this;
}
/**
* Add a handler for an additional array of methods. Note that {@code MethodHandlers}
* does not allow replacing the handler for an already existing method.
*/
public MethodHandlers addMethods(RestHandler handler, RestRequest.Method... methods) {
for (RestRequest.Method method : methods) {
addMethod(method, handler);
}
return this;
}
/**
* Return an Optional-wrapped handler for a method, or an empty optional if
* there is no handler.
*/
public Optional<RestHandler> getHandler(RestRequest.Method method) {
return Optional.ofNullable(methodHandlers.get(method));
}
/**
* Return a set of all valid HTTP methods for the particular path
*/
public Set<RestRequest.Method> getValidMethods() {
return methodHandlers.keySet();
}
}

View File

@ -42,9 +42,14 @@ import org.elasticsearch.usage.UsageService;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
@ -58,12 +63,7 @@ import static org.elasticsearch.rest.RestStatus.OK;
public class RestController extends AbstractComponent implements HttpServerTransport.Dispatcher {
private final PathTrie<RestHandler> getHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> postHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> putHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> deleteHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> headHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> optionsHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<MethodHandlers> handlers = new PathTrie<>(RestUtils.REST_DECODER);
private final UnaryOperator<RestHandler> handlerWrapper;
@ -148,24 +148,19 @@ public class RestController extends AbstractComponent implements HttpServerTrans
* @param handler The handler to actually execute
*/
public void registerHandler(RestRequest.Method method, String path, RestHandler handler) {
PathTrie<RestHandler> handlers = getHandlersForMethod(method);
if (handlers != null) {
handlers.insert(path, handler);
if (handler instanceof BaseRestHandler) {
usageService.addRestHandler((BaseRestHandler) handler);
}
} else {
throw new IllegalArgumentException("Can't handle [" + method + "] for path [" + path + "]");
if (handler instanceof BaseRestHandler) {
usageService.addRestHandler((BaseRestHandler) handler);
}
handlers.insertOrUpdate(path, new MethodHandlers(path, handler, method), (mHandlers, newMHandler) -> {
return mHandlers.addMethods(handler, method);
});
}
/**
* @param request The current request. Must not be null.
* @return true iff the circuit breaker limit must be enforced for processing this request.
*/
public boolean canTripCircuitBreaker(RestRequest request) {
RestHandler handler = getHandler(request);
return (handler != null) ? handler.canTripCircuitBreaker() : true;
public boolean canTripCircuitBreaker(final Optional<RestHandler> handler) {
return handler.map(h -> h.canTripCircuitBreaker()).orElse(true);
}
@Override
@ -174,32 +169,11 @@ public class RestController extends AbstractComponent implements HttpServerTrans
handleFavicon(request, channel);
return;
}
RestChannel responseChannel = channel;
try {
final int contentLength = request.hasContent() ? request.content().length() : 0;
assert contentLength >= 0 : "content length was negative, how is that possible?";
final RestHandler handler = getHandler(request);
if (contentLength > 0 && hasContentType(request, handler) == false) {
sendContentTypeErrorMessage(request, responseChannel);
} else if (contentLength > 0 && handler != null && handler.supportsContentStream() &&
request.getXContentType() != XContentType.JSON && request.getXContentType() != XContentType.SMILE) {
responseChannel.sendResponse(BytesRestResponse.createSimpleErrorResponse(responseChannel,
RestStatus.NOT_ACCEPTABLE, "Content-Type [" + request.getXContentType() +
"] does not support stream parsing. Use JSON or SMILE instead"));
} else {
if (canTripCircuitBreaker(request)) {
inFlightRequestsBreaker(circuitBreakerService).addEstimateBytesAndMaybeBreak(contentLength, "<http_request>");
} else {
inFlightRequestsBreaker(circuitBreakerService).addWithoutBreaking(contentLength);
}
// iff we could reserve bytes for the request we need to send the response also over this channel
responseChannel = new ResourceHandlingHttpChannel(channel, circuitBreakerService, contentLength);
dispatchRequest(request, responseChannel, client, threadContext, handler);
}
tryAllHandlers(request, channel, threadContext);
} catch (Exception e) {
try {
responseChannel.sendResponse(new BytesRestResponse(channel, e));
channel.sendResponse(new BytesRestResponse(channel, e));
} catch (Exception inner) {
inner.addSuppressed(e);
logger.error((Supplier<?>) () ->
@ -233,33 +207,56 @@ public class RestController extends AbstractComponent implements HttpServerTrans
}
}
void dispatchRequest(final RestRequest request, final RestChannel channel, final NodeClient client, ThreadContext threadContext,
final RestHandler handler) throws Exception {
if (checkRequestParameters(request, channel) == false) {
channel
.sendResponse(BytesRestResponse.createSimpleErrorResponse(channel,BAD_REQUEST, "error traces in responses are disabled."));
} else {
for (String key : headersToCopy) {
String httpHeader = request.header(key);
if (httpHeader != null) {
threadContext.putHeader(key, httpHeader);
}
}
/**
* Dispatch the request, if possible, returning true if a response was sent or false otherwise.
*/
boolean dispatchRequest(final RestRequest request, final RestChannel channel, final NodeClient client,
ThreadContext threadContext, final Optional<RestHandler> mHandler) throws Exception {
final int contentLength = request.hasContent() ? request.content().length() : 0;
if (handler == null) {
if (request.method() == RestRequest.Method.OPTIONS) {
// when we have OPTIONS request, simply send OK by default (with the Access Control Origin header which gets automatically added)
RestChannel responseChannel = channel;
// Indicator of whether a response was sent or not
boolean requestHandled;
channel.sendResponse(new BytesRestResponse(OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
if (contentLength > 0 && mHandler.map(h -> hasContentType(request, h) == false).orElse(false)) {
sendContentTypeErrorMessage(request, channel);
requestHandled = true;
} else if (contentLength > 0 && mHandler.map(h -> h.supportsContentStream()).orElse(false) &&
request.getXContentType() != XContentType.JSON && request.getXContentType() != XContentType.SMILE) {
channel.sendResponse(BytesRestResponse.createSimpleErrorResponse(channel,
RestStatus.NOT_ACCEPTABLE, "Content-Type [" + request.getXContentType() +
"] does not support stream parsing. Use JSON or SMILE instead"));
requestHandled = true;
} else if (mHandler.isPresent()) {
try {
if (canTripCircuitBreaker(mHandler)) {
inFlightRequestsBreaker(circuitBreakerService).addEstimateBytesAndMaybeBreak(contentLength, "<http_request>");
} else {
final String msg = "No handler found for uri [" + request.uri() + "] and method [" + request.method() + "]";
channel.sendResponse(new BytesRestResponse(BAD_REQUEST, msg));
inFlightRequestsBreaker(circuitBreakerService).addWithoutBreaking(contentLength);
}
// iff we could reserve bytes for the request we need to send the response also over this channel
responseChannel = new ResourceHandlingHttpChannel(channel, circuitBreakerService, contentLength);
final RestHandler wrappedHandler = mHandler.map(h -> handlerWrapper.apply(h)).get();
wrappedHandler.handleRequest(request, responseChannel, client);
requestHandled = true;
} catch (Exception e) {
responseChannel.sendResponse(new BytesRestResponse(responseChannel, e));
// We "handled" the request by returning a response, even though it was an error
requestHandled = true;
}
} else {
if (request.method() == RestRequest.Method.OPTIONS) {
// when we have OPTIONS request, simply send OK by default (with the Access Control Origin header which gets automatically added)
channel.sendResponse(new BytesRestResponse(OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
requestHandled = true;
} else {
final RestHandler wrappedHandler = Objects.requireNonNull(handlerWrapper.apply(handler));
wrappedHandler.handleRequest(request, channel, client);
requestHandled = false;
}
}
// Return true if the request was handled, false otherwise.
return requestHandled;
}
/**
@ -308,32 +305,69 @@ public class RestController extends AbstractComponent implements HttpServerTrans
return true;
}
private RestHandler getHandler(RestRequest request) {
String path = getPath(request);
PathTrie<RestHandler> handlers = getHandlersForMethod(request.method());
if (handlers != null) {
return handlers.retrieve(path, request.params());
} else {
return null;
void tryAllHandlers(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) throws Exception {
for (String key : headersToCopy) {
String httpHeader = request.header(key);
if (httpHeader != null) {
threadContext.putHeader(key, httpHeader);
}
}
// Request execution flag
boolean requestHandled = false;
if (checkRequestParameters(request, channel) == false) {
channel.sendResponse(BytesRestResponse.createSimpleErrorResponse(channel,
BAD_REQUEST, "error traces in responses are disabled."));
requestHandled = true;
}
// Loop through all possible handlers, attempting to dispatch the request
Iterator<MethodHandlers> allHandlers = getAllHandlers(request);
for (Iterator<MethodHandlers> it = allHandlers; it.hasNext(); ) {
final Optional<RestHandler> mHandler = Optional.ofNullable(it.next()).flatMap(mh -> mh.getHandler(request.method()));
requestHandled = dispatchRequest(request, channel, client, threadContext, mHandler);
if (requestHandled) {
break;
}
}
// If request has not been handled, fallback to a bad request error.
if (requestHandled == false) {
handleBadRequest(request, channel);
}
}
private PathTrie<RestHandler> getHandlersForMethod(RestRequest.Method method) {
if (method == RestRequest.Method.GET) {
return getHandlers;
} else if (method == RestRequest.Method.POST) {
return postHandlers;
} else if (method == RestRequest.Method.PUT) {
return putHandlers;
} else if (method == RestRequest.Method.DELETE) {
return deleteHandlers;
} else if (method == RestRequest.Method.HEAD) {
return headHandlers;
} else if (method == RestRequest.Method.OPTIONS) {
return optionsHandlers;
} else {
return null;
Iterator<MethodHandlers> getAllHandlers(final RestRequest request) {
// Between retrieving the correct path, we need to reset the parameters,
// otherwise parameters are parsed out of the URI that aren't actually handled.
final Map<String, String> originalParams = new HashMap<>(request.params());
return handlers.retrieveAll(getPath(request), () -> {
// PathTrie modifies the request, so reset the params between each iteration
request.params().clear();
request.params().putAll(originalParams);
return request.params();
});
}
/**
* Handle a requests with no candidate handlers (return a 400 Bad Request
* error).
*/
private void handleBadRequest(RestRequest request, RestChannel channel) {
channel.sendResponse(new BytesRestResponse(BAD_REQUEST,
"No handler found for uri [" + request.uri() + "] and method [" + request.method() + "]"));
}
/**
* Get the valid set of HTTP methods for a REST request.
*/
private Set<RestRequest.Method> getValidHandlerMethodSet(RestRequest request) {
Set<RestRequest.Method> validMethods = new HashSet<>();
Iterator<MethodHandlers> allHandlers = getAllHandlers(request);
for (Iterator<MethodHandlers> it = allHandlers; it.hasNext(); ) {
Optional.ofNullable(it.next()).map(mh -> validMethods.addAll(mh.getValidMethods()));
}
return validMethods;
}
private String getPath(RestRequest request) {

View File

@ -36,7 +36,7 @@ import static org.elasticsearch.rest.RestRequest.Method.GET;
public class RestGetTaskAction extends BaseRestHandler {
public RestGetTaskAction(Settings settings, RestController controller) {
super(settings);
controller.registerHandler(GET, "/_tasks/{taskId}", this);
controller.registerHandler(GET, "/_tasks/{task_id}", this);
}
@Override
@ -46,7 +46,7 @@ public class RestGetTaskAction extends BaseRestHandler {
@Override
public RestChannelConsumer prepareRequest(final RestRequest request, final NodeClient client) throws IOException {
TaskId taskId = new TaskId(request.param("taskId"));
TaskId taskId = new TaskId(request.param("task_id"));
boolean waitForCompletion = request.paramAsBoolean("wait_for_completion", false);
TimeValue timeout = request.paramAsTime("timeout", null);

View File

@ -119,7 +119,7 @@ public class ActionModuleTests extends ESTestCase {
// At this point the easiest way to confirm that a handler is loaded is to try to register another one on top of it and to fail
Exception e = expectThrows(IllegalArgumentException.class, () ->
actionModule.getRestController().registerHandler(Method.GET, "/", null));
assertThat(e.getMessage(), startsWith("Path [/] already has a value [" + RestMainAction.class.getName()));
assertThat(e.getMessage(), startsWith("Cannot replace existing handler for [/] for method: GET"));
}
public void testPluginCantOverwriteBuiltinRestHandler() throws IOException {
@ -139,7 +139,7 @@ public class ActionModuleTests extends ESTestCase {
settings.getIndexScopedSettings(), settings.getClusterSettings(), settings.getSettingsFilter(), threadPool,
singletonList(dupsMainAction), null, null, usageService);
Exception e = expectThrows(IllegalArgumentException.class, () -> actionModule.initRestHandlers(null));
assertThat(e.getMessage(), startsWith("Path [/] already has a value [" + RestMainAction.class.getName()));
assertThat(e.getMessage(), startsWith("Cannot replace existing handler for [/] for method: GET"));
} finally {
threadPool.shutdown();
}
@ -174,7 +174,7 @@ public class ActionModuleTests extends ESTestCase {
// At this point the easiest way to confirm that a handler is loaded is to try to register another one on top of it and to fail
Exception e = expectThrows(IllegalArgumentException.class, () ->
actionModule.getRestController().registerHandler(Method.GET, "/_dummy", null));
assertThat(e.getMessage(), startsWith("Path [/_dummy] already has a value [" + FakeHandler.class.getName()));
assertThat(e.getMessage(), startsWith("Cannot replace existing handler for [/_dummy] for method: GET"));
} finally {
threadPool.shutdown();
}

View File

@ -23,11 +23,14 @@ import org.elasticsearch.rest.RestUtils;
import org.elasticsearch.test.ESTestCase;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import org.elasticsearch.common.path.PathTrie.TrieMatchingMode;
public class PathTrieTests extends ESTestCase {
public static final PathTrie.Decoder NO_DECODER = new PathTrie.Decoder() {
@ -114,6 +117,102 @@ public class PathTrieTests extends ESTestCase {
assertThat(trie.retrieve("/v/x/c", params), equalTo("test6"));
}
// https://github.com/elastic/elasticsearch/pull/17916
public void testWildcardMatchingModes() {
PathTrie<String> trie = new PathTrie<>(NO_DECODER);
trie.insert("{testA}", "test1");
trie.insert("{testA}/{testB}", "test2");
trie.insert("a/{testA}", "test3");
trie.insert("{testA}/b", "test4");
trie.insert("{testA}/b/c", "test5");
trie.insert("a/{testB}/c", "test6");
trie.insert("a/b/{testC}", "test7");
trie.insert("{testA}/b/{testB}", "test8");
trie.insert("x/{testA}/z", "test9");
trie.insert("{testA}/{testB}/{testC}", "test10");
Map<String, String> params = new HashMap<>();
assertThat(trie.retrieve("/a", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/a", params, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED), equalTo("test1"));
assertThat(trie.retrieve("/a", params, TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED), equalTo("test1"));
assertThat(trie.retrieve("/a", params, TrieMatchingMode.WILDCARD_NODES_ALLOWED), equalTo("test1"));
Iterator<String> allPaths = trie.retrieveAll("/a", () -> params);
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo("test1"));
assertThat(allPaths.next(), equalTo("test1"));
assertThat(allPaths.next(), equalTo("test1"));
assertFalse(allPaths.hasNext());
assertThat(trie.retrieve("/a/b", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/a/b", params, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED), equalTo("test4"));
assertThat(trie.retrieve("/a/b", params, TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED), equalTo("test3"));
assertThat(trie.retrieve("/a/b", params, TrieMatchingMode.WILDCARD_NODES_ALLOWED), equalTo("test3"));
allPaths = trie.retrieveAll("/a/b", () -> params);
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo("test4"));
assertThat(allPaths.next(), equalTo("test3"));
assertThat(allPaths.next(), equalTo("test3"));
assertFalse(allPaths.hasNext());
assertThat(trie.retrieve("/a/b/c", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/a/b/c", params, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED), equalTo("test5"));
assertThat(trie.retrieve("/a/b/c", params, TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED), equalTo("test7"));
assertThat(trie.retrieve("/a/b/c", params, TrieMatchingMode.WILDCARD_NODES_ALLOWED), equalTo("test7"));
allPaths = trie.retrieveAll("/a/b/c", () -> params);
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo("test5"));
assertThat(allPaths.next(), equalTo("test7"));
assertThat(allPaths.next(), equalTo("test7"));
assertFalse(allPaths.hasNext());
assertThat(trie.retrieve("/x/y/z", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/x/y/z", params, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED), nullValue());
assertThat(trie.retrieve("/x/y/z", params, TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED), nullValue());
assertThat(trie.retrieve("/x/y/z", params, TrieMatchingMode.WILDCARD_NODES_ALLOWED), equalTo("test9"));
allPaths = trie.retrieveAll("/x/y/z", () -> params);
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo("test9"));
assertFalse(allPaths.hasNext());
assertThat(trie.retrieve("/d/e/f", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/d/e/f", params, TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED), nullValue());
assertThat(trie.retrieve("/d/e/f", params, TrieMatchingMode.WILDCARD_LEAF_NODES_ALLOWED), nullValue());
assertThat(trie.retrieve("/d/e/f", params, TrieMatchingMode.WILDCARD_NODES_ALLOWED), equalTo("test10"));
allPaths = trie.retrieveAll("/d/e/f", () -> params);
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo(null));
assertThat(allPaths.next(), equalTo("test10"));
assertFalse(allPaths.hasNext());
}
// https://github.com/elastic/elasticsearch/pull/17916
public void testExplicitMatchingMode() {
PathTrie<String> trie = new PathTrie<>(NO_DECODER);
trie.insert("{testA}", "test1");
trie.insert("a", "test2");
trie.insert("{testA}/{testB}", "test3");
trie.insert("a/{testB}", "test4");
trie.insert("{testB}/b", "test5");
trie.insert("a/b", "test6");
trie.insert("{testA}/b/{testB}", "test7");
trie.insert("x/{testA}/z", "test8");
trie.insert("{testA}/{testB}/{testC}", "test9");
trie.insert("a/b/c", "test10");
Map<String, String> params = new HashMap<>();
assertThat(trie.retrieve("/a", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), equalTo("test2"));
assertThat(trie.retrieve("/x", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/a/b", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), equalTo("test6"));
assertThat(trie.retrieve("/a/x", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
assertThat(trie.retrieve("/a/b/c", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), equalTo("test10"));
assertThat(trie.retrieve("/x/y/z", params, TrieMatchingMode.EXPLICIT_NODES_ONLY), nullValue());
}
public void testSamePathConcreteResolution() {
PathTrie<String> trie = new PathTrie<>(NO_DECODER);
trie.insert("{x}/{y}/{z}", "test1");

View File

@ -25,6 +25,7 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
@ -49,8 +50,10 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
@ -61,7 +64,9 @@ import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doCallRealMethod;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class RestControllerTests extends ESTestCase {
@ -95,7 +100,6 @@ public class RestControllerTests extends ESTestCase {
httpServerTransport.start();
}
public void testApplyRelevantHeaders() throws Exception {
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
Set<String> headers = new HashSet<>(Arrays.asList("header.1", "header.2"));
@ -104,12 +108,26 @@ public class RestControllerTests extends ESTestCase {
restHeaders.put("header.1", Collections.singletonList("true"));
restHeaders.put("header.2", Collections.singletonList("true"));
restHeaders.put("header.3", Collections.singletonList("false"));
restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build(), null, null,
threadContext, (RestRequest request, RestChannel channel, NodeClient client) -> {
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
});
RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build();
final RestController spyRestController = spy(restController);
when(spyRestController.getAllHandlers(fakeRequest))
.thenReturn(new Iterator<MethodHandlers>() {
@Override
public boolean hasNext() {
return false;
}
@Override
public MethodHandlers next() {
return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> {
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
}, RestRequest.Method.GET);
}
});
AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST);
restController.dispatchRequest(fakeRequest, channel, threadContext);
// the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the
// context in this test
assertEquals("true", threadContext.getHeader("header.1"));
@ -124,10 +142,22 @@ public class RestControllerTests extends ESTestCase {
controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));
assertTrue(controller.canTripCircuitBreaker(new FakeRestRequest.Builder(xContentRegistry()).withPath("/trip").build()));
RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/trip").build();
for (Iterator<MethodHandlers> it = controller.getAllHandlers(fakeRequest); it.hasNext(); ) {
Optional<MethodHandlers> mHandler = Optional.ofNullable(it.next());
assertTrue(mHandler.map(mh -> controller.canTripCircuitBreaker(mh.getHandler(RestRequest.Method.GET))).orElse(true));
}
// assume trip even on unknown paths
assertTrue(controller.canTripCircuitBreaker(new FakeRestRequest.Builder(xContentRegistry()).withPath("/unknown-path").build()));
assertFalse(controller.canTripCircuitBreaker(new FakeRestRequest.Builder(xContentRegistry()).withPath("/do-not-trip").build()));
fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/unknown-path").build();
for (Iterator<MethodHandlers> it = controller.getAllHandlers(fakeRequest); it.hasNext(); ) {
Optional<MethodHandlers> mHandler = Optional.ofNullable(it.next());
assertTrue(mHandler.map(mh -> controller.canTripCircuitBreaker(mh.getHandler(RestRequest.Method.GET))).orElse(true));
}
fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withPath("/do-not-trip").build();
for (Iterator<MethodHandlers> it = controller.getAllHandlers(fakeRequest); it.hasNext(); ) {
Optional<MethodHandlers> mHandler = Optional.ofNullable(it.next());
assertFalse(mHandler.map(mh -> controller.canTripCircuitBreaker(mh.getHandler(RestRequest.Method.GET))).orElse(false));
}
}
public void testRegisterAsDeprecatedHandler() {
@ -182,7 +212,8 @@ public class RestControllerTests extends ESTestCase {
final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null,
circuitBreakerService, usageService);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(), null, null, threadContext, handler);
restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(),
null, null, threadContext, Optional.of(handler));
assertTrue(wrapperCalled.get());
assertFalse(handlerCalled.get());
}