diff --git a/core/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java b/core/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java index 11723c3d50a..583bcbc561d 100644 --- a/core/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java +++ b/core/src/main/java/org/elasticsearch/index/shard/SearchOperationListener.java @@ -21,6 +21,7 @@ package org.elasticsearch.index.shard; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.search.internal.SearchContext; import java.util.List; @@ -104,6 +105,14 @@ public interface SearchOperationListener { */ default void onFreeScrollContext(SearchContext context) {}; + /** + * Executed prior to using a {@link SearchContext} that has been retrieved + * from the active contexts. If the context is deemed invalid a runtime + * exception can be thrown, which will prevent the context from being used. + * @param context the context retrieved from the active contexts + */ + default void validateSearchContext(SearchContext context) {} + /** * A Composite listener that multiplexes calls to each of the listeners methods. */ @@ -225,5 +234,18 @@ public interface SearchOperationListener { } } } + + @Override + public void validateSearchContext(SearchContext context) { + Exception exception = null; + for (SearchOperationListener listener : listeners) { + try { + listener.validateSearchContext(context); + } catch (Exception e) { + exception = ExceptionsHelper.useOrSuppress(exception, e); + } + } + ExceptionsHelper.reThrowIfNotNull(exception); + } } } diff --git a/core/src/main/java/org/elasticsearch/search/SearchService.java b/core/src/main/java/org/elasticsearch/search/SearchService.java index f9d0b3dc338..4174da37243 100644 --- a/core/src/main/java/org/elasticsearch/search/SearchService.java +++ b/core/src/main/java/org/elasticsearch/search/SearchService.java @@ -437,7 +437,15 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv if (context == null) { throw new SearchContextMissingException(id); } - return context; + + SearchOperationListener operationListener = context.indexShard().getSearchOperationListener(); + try { + operationListener.validateSearchContext(context); + return context; + } catch (Exception e) { + processFailure(context, e); + throw e; + } } final SearchContext createAndPutContext(ShardSearchRequest request) throws IOException { diff --git a/core/src/test/java/org/elasticsearch/index/shard/SearchOperationListenerTests.java b/core/src/test/java/org/elasticsearch/index/shard/SearchOperationListenerTests.java index 1721e5f5e5d..fafdbe6755b 100644 --- a/core/src/test/java/org/elasticsearch/index/shard/SearchOperationListenerTests.java +++ b/core/src/test/java/org/elasticsearch/index/shard/SearchOperationListenerTests.java @@ -18,9 +18,6 @@ */ package org.elasticsearch.index.shard; -import org.apache.lucene.index.Term; -import org.elasticsearch.client.Client; -import org.elasticsearch.index.engine.Engine; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TestSearchContext; @@ -32,6 +29,9 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + public class SearchOperationListenerTests extends ESTestCase { // this test also tests if calls are correct if one or more listeners throw exceptions @@ -46,6 +46,7 @@ public class SearchOperationListenerTests extends ESTestCase { AtomicInteger freeContext = new AtomicInteger(); AtomicInteger newScrollContext = new AtomicInteger(); AtomicInteger freeScrollContext = new AtomicInteger(); + AtomicInteger validateSearchContext = new AtomicInteger(); AtomicInteger timeInNanos = new AtomicInteger(randomIntBetween(0, 10)); SearchOperationListener listener = new SearchOperationListener() { @Override @@ -109,17 +110,26 @@ public class SearchOperationListenerTests extends ESTestCase { assertNotNull(context); freeScrollContext.incrementAndGet(); } + + @Override + public void validateSearchContext(SearchContext context) { + assertNotNull(context); + validateSearchContext.incrementAndGet(); + } }; SearchOperationListener throwingListener = (SearchOperationListener) Proxy.newProxyInstance( SearchOperationListener.class.getClassLoader(), new Class[]{SearchOperationListener.class}, (a,b,c) -> { throw new RuntimeException();}); + int throwingListeners = 0; final List indexingOperationListeners = new ArrayList<>(Arrays.asList(listener, listener)); if (randomBoolean()) { indexingOperationListeners.add(throwingListener); + throwingListeners++; if (randomBoolean()) { indexingOperationListeners.add(throwingListener); + throwingListeners++; } } Collections.shuffle(indexingOperationListeners, random()); @@ -137,6 +147,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onFetchPhase(ctx, timeInNanos.get()); assertEquals(0, preFetch.get()); @@ -149,6 +160,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onPreQueryPhase(ctx); assertEquals(0, preFetch.get()); @@ -161,6 +173,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onPreFetchPhase(ctx); assertEquals(2, preFetch.get()); @@ -173,6 +186,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onFailedFetchPhase(ctx); assertEquals(2, preFetch.get()); @@ -185,6 +199,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onFailedQueryPhase(ctx); assertEquals(2, preFetch.get()); @@ -197,6 +212,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onNewContext(ctx); assertEquals(2, preFetch.get()); @@ -209,6 +225,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(0, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onNewScrollContext(ctx); assertEquals(2, preFetch.get()); @@ -221,6 +238,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(2, newScrollContext.get()); assertEquals(0, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onFreeContext(ctx); assertEquals(2, preFetch.get()); @@ -233,6 +251,7 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(2, newScrollContext.get()); assertEquals(2, freeContext.get()); assertEquals(0, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); compositeListener.onFreeScrollContext(ctx); assertEquals(2, preFetch.get()); @@ -245,5 +264,28 @@ public class SearchOperationListenerTests extends ESTestCase { assertEquals(2, newScrollContext.get()); assertEquals(2, freeContext.get()); assertEquals(2, freeScrollContext.get()); + assertEquals(0, validateSearchContext.get()); + + if (throwingListeners == 0) { + compositeListener.validateSearchContext(ctx); + } else { + RuntimeException expected = expectThrows(RuntimeException.class, () -> compositeListener.validateSearchContext(ctx)); + assertNull(expected.getMessage()); + assertEquals(throwingListeners - 1, expected.getSuppressed().length); + if (throwingListeners > 1) { + assertThat(expected.getSuppressed()[0], not(sameInstance(expected))); + } + } + assertEquals(2, preFetch.get()); + assertEquals(2, preQuery.get()); + assertEquals(2, failedFetch.get()); + assertEquals(2, failedQuery.get()); + assertEquals(2, onQuery.get()); + assertEquals(2, onFetch.get()); + assertEquals(2, newContext.get()); + assertEquals(2, newScrollContext.get()); + assertEquals(2, freeContext.get()); + assertEquals(2, freeScrollContext.get()); + assertEquals(2, validateSearchContext.get()); } }