Ensure validation of the reader context is executed first (#61831)

This change makes sure that reader context is validated (`SearchOperationListener#validateReaderContext)
before any other operation and that it is correctly recycled or removed at the end of the operation.
This commit also fixes a race condition bug that would allocate the security reader for scrolls more than once.

Relates #61446

Co-authored-by: Nhat Nguyen <nhat.nguyen@elastic.co>
This commit is contained in:
Jim Ferenczi 2020-09-07 10:47:24 +02:00 committed by Nhat Nguyen
parent 44bd4a6004
commit 4d528e91a1
10 changed files with 146 additions and 130 deletions

View File

@ -547,24 +547,27 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
} }
/** /**
* This method should be called if a search phase failed to ensure all relevant search contexts and resources are released. * This method should be called if a search phase failed to ensure all relevant reader contexts are released.
* this method will also notify the listener and sends back a failure to the user. * This method will also notify the listener and sends back a failure to the user.
* *
* @param exception the exception explaining or causing the phase failure * @param exception the exception explaining or causing the phase failure
*/ */
private void raisePhaseFailure(SearchPhaseExecutionException exception) { private void raisePhaseFailure(SearchPhaseExecutionException exception) {
results.getSuccessfulResults().forEach((entry) -> { // we don't release persistent readers (point in time).
if (entry.getContextId() != null) { if (request.pointInTimeBuilder() == null) {
try { results.getSuccessfulResults().forEach((entry) -> {
SearchShardTarget searchShardTarget = entry.getSearchShardTarget(); if (entry.getContextId() != null) {
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); try {
sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices()); SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
} catch (Exception inner) { Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
inner.addSuppressed(exception); sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices());
logger.trace("failed to release context", inner); } catch (Exception inner) {
inner.addSuppressed(exception);
logger.trace("failed to release context", inner);
}
} }
} });
}); }
listener.onFailure(exception); listener.onFailure(exception);
} }

View File

@ -98,11 +98,13 @@ final class DfsQueryPhase extends SearchPhase {
progressListener.notifyQueryFailure(shardIndex, searchShardTarget, exception); progressListener.notifyQueryFailure(shardIndex, searchShardTarget, exception);
counter.onFailure(shardIndex, searchShardTarget, exception); counter.onFailure(shardIndex, searchShardTarget, exception);
} finally { } finally {
// the query might not have been executed at all (for example because thread pool rejected if (context.getRequest().pointInTimeBuilder() == null) {
// execution) and the search context that was created in dfs phase might not be released. // the query might not have been executed at all (for example because thread pool rejected
// release it again to be in the safe side // execution) and the search context that was created in dfs phase might not be released.
context.sendReleaseSearchContext( // release it again to be in the safe side
querySearchRequest.contextId(), connection, searchShardTarget.getOriginalIndices()); context.sendReleaseSearchContext(
querySearchRequest.contextId(), connection, searchShardTarget.getOriginalIndices());
}
} }
} }
}); });

View File

@ -206,11 +206,11 @@ final class FetchSearchPhase extends SearchPhase {
* Releases shard targets that are not used in the docsIdsToLoad. * Releases shard targets that are not used in the docsIdsToLoad.
*/ */
private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) { private void releaseIrrelevantSearchContext(QuerySearchResult queryResult) {
// we only release search context that we did not fetch from if we are not scrolling // we only release search context that we did not fetch from, if we are not scrolling
// and if it has at lease one hit that didn't make it to the global topDocs // or using a PIT and if it has at least one hit that didn't make it to the global topDocs
if (context.getRequest().scroll() == null && if (queryResult.hasSearchContext()
context.getRequest().pointInTimeBuilder() == null && && context.getRequest().scroll() == null
queryResult.hasSearchContext()) { && context.getRequest().pointInTimeBuilder() == null) {
try { try {
SearchShardTarget searchShardTarget = queryResult.getSearchShardTarget(); SearchShardTarget searchShardTarget = queryResult.getSearchShardTarget();
Transport.Connection connection = context.getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId()); Transport.Connection connection = context.getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());

View File

@ -113,7 +113,7 @@ public interface SearchOperationListener {
* @param readerContext The reader context used by this request. * @param readerContext The reader context used by this request.
* @param transportRequest the request that is going to use the search context * @param transportRequest the request that is going to use the search context
*/ */
default void validateSearchContext(ReaderContext readerContext, TransportRequest transportRequest) {} default void validateReaderContext(ReaderContext readerContext, TransportRequest transportRequest) {}
/** /**
* A Composite listener that multiplexes calls to each of the listeners methods. * A Composite listener that multiplexes calls to each of the listeners methods.
@ -238,11 +238,11 @@ public interface SearchOperationListener {
} }
@Override @Override
public void validateSearchContext(ReaderContext readerContext, TransportRequest request) { public void validateReaderContext(ReaderContext readerContext, TransportRequest request) {
Exception exception = null; Exception exception = null;
for (SearchOperationListener listener : listeners) { for (SearchOperationListener listener : listeners) {
try { try {
listener.validateSearchContext(readerContext, request); listener.validateReaderContext(readerContext, request);
} catch (Exception e) { } catch (Exception e) {
exception = ExceptionsHelper.useOrSuppress(exception, e); exception = ExceptionsHelper.useOrSuppress(exception, e);
} }

View File

@ -118,6 +118,7 @@ import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
import org.elasticsearch.threadpool.Scheduler.Cancellable; import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPool.Names; import org.elasticsearch.threadpool.ThreadPool.Names;
import org.elasticsearch.transport.TransportRequest;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
@ -353,7 +354,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
return context.dfsResult(); return context.dfsResult();
} catch (Exception e) { } catch (Exception e) {
logger.trace("Dfs phase failed", e); logger.trace("Dfs phase failed", e);
processFailure(request, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
} }
@ -396,12 +397,12 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
try (Releasable ignored = markAsUsed) { try (Releasable ignored = markAsUsed) {
listener.onFailure(exc); listener.onFailure(exc);
} finally { } finally {
processFailure(request, readerContext, exc); processFailure(readerContext, exc);
} }
return; return;
} }
if (canRewriteToMatchNone(canMatchRequest.source()) if (canRewriteToMatchNone(canMatchRequest.source())
&& canMatchRequest.source().query() instanceof MatchNoneQueryBuilder) { && canMatchRequest.source().query() instanceof MatchNoneQueryBuilder) {
try (Releasable ignored = markAsUsed) { try (Releasable ignored = markAsUsed) {
if (orig.readerId() == null) { if (orig.readerId() == null) {
try { try {
@ -420,17 +421,8 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
} }
// fork the execution in the search thread pool // fork the execution in the search thread pool
runAsync(getExecutor(shard), () -> { runAsync(getExecutor(shard), () -> executeQueryPhase(orig, task, readerContext),
try (Releasable ignored = markAsUsed) { wrapFailureListener(listener, readerContext, markAsUsed));
return executeQueryPhase(orig, task, readerContext);
}
}, ActionListener.wrap(listener::onResponse, exc -> {
try (Releasable ignored = markAsUsed) {
listener.onFailure(exc);
} finally {
processFailure(request, readerContext, exc);
}
}));
} }
@Override @Override
@ -442,7 +434,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
private IndexShard getShard(ShardSearchRequest request) { private IndexShard getShard(ShardSearchRequest request) {
if (request.readerId() != null) { if (request.readerId() != null) {
return findReaderContext(request.readerId()).indexShard(); return findReaderContext(request.readerId(), request).indexShard();
} else { } else {
return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id()); return indicesService.indexServiceSafe(request.shardId().getIndex()).getShard(request.shardId().id());
} }
@ -481,7 +473,7 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
(Exception) e.getCause() : new ElasticsearchException(e.getCause()); (Exception) e.getCause() : new ElasticsearchException(e.getCause());
} }
logger.trace("Query phase failed", e); logger.trace("Query phase failed", e);
processFailure(request, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
} }
@ -501,13 +493,12 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
public void executeQueryPhase(InternalScrollSearchRequest request, public void executeQueryPhase(InternalScrollSearchRequest request,
SearchShardTask task, SearchShardTask task,
ActionListener<ScrollQuerySearchResult> listener) { ActionListener<ScrollQuerySearchResult> listener) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId()); final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed = readerContext.markAsUsed();
runAsync(getExecutor(readerContext.indexShard()), () -> { runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (Releasable ignored = readerContext.markAsUsed(); try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) { SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) {
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request);
if (request.scroll() != null && request.scroll().keepAlive() != null) { if (request.scroll() != null && request.scroll().keepAlive() != null) {
final long keepAlive = request.scroll().keepAlive().millis(); final long keepAlive = request.scroll().keepAlive().millis();
checkKeepAliveLimit(keepAlive); checkKeepAliveLimit(keepAlive);
@ -521,21 +512,20 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
return new ScrollQuerySearchResult(searchContext.queryResult(), searchContext.shardTarget()); return new ScrollQuerySearchResult(searchContext.queryResult(), searchContext.shardTarget());
} catch (Exception e) { } catch (Exception e) {
logger.trace("Query phase failed", e); logger.trace("Query phase failed", e);
processFailure(shardSearchRequest, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
}, listener); }, ActionListener.runAfter(listener, markAsUsed::close));
} }
public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, ActionListener<QuerySearchResult> listener) { public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, ActionListener<QuerySearchResult> listener) {
final ReaderContext readerContext = findReaderContext(request.contextId()); final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final Releasable markAsUsed = readerContext.markAsUsed();
runAsync(getExecutor(readerContext.indexShard()), () -> { runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
readerContext.setAggregatedDfs(request.dfs()); readerContext.setAggregatedDfs(request.dfs());
try (Releasable ignored = readerContext.markAsUsed(); try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, true);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, true);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) { SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) {
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request);
searchContext.searcher().setAggregatedDfs(request.dfs()); searchContext.searcher().setAggregatedDfs(request.dfs());
queryPhase.execute(searchContext); queryPhase.execute(searchContext);
if (searchContext.queryResult().hasSearchContext() == false && readerContext.singleSession()) { if (searchContext.queryResult().hasSearchContext() == false && readerContext.singleSession()) {
@ -552,10 +542,10 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
} catch (Exception e) { } catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
logger.trace("Query phase failed", e); logger.trace("Query phase failed", e);
processFailure(shardSearchRequest, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
}, listener); }, wrapFailureListener(listener, readerContext, markAsUsed));
} }
private Executor getExecutor(IndexShard indexShard) { private Executor getExecutor(IndexShard indexShard) {
@ -573,13 +563,12 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
public void executeFetchPhase(InternalScrollSearchRequest request, SearchShardTask task, public void executeFetchPhase(InternalScrollSearchRequest request, SearchShardTask task,
ActionListener<ScrollQueryFetchSearchResult> listener) { ActionListener<ScrollQueryFetchSearchResult> listener) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId()); final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed = readerContext.markAsUsed();
runAsync(getExecutor(readerContext.indexShard()), () -> { runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (Releasable ignored = readerContext.markAsUsed(); try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) { SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)) {
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request);
if (request.scroll() != null && request.scroll().keepAlive() != null) { if (request.scroll() != null && request.scroll().keepAlive() != null) {
checkKeepAliveLimit(request.scroll().keepAlive().millis()); checkKeepAliveLimit(request.scroll().keepAlive().millis());
readerContext.keepAlive(request.scroll().keepAlive().millis()); readerContext.keepAlive(request.scroll().keepAlive().millis());
@ -594,19 +583,18 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
} catch (Exception e) { } catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
logger.trace("Fetch phase failed", e); logger.trace("Fetch phase failed", e);
processFailure(shardSearchRequest, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
}, listener); }, ActionListener.runAfter(listener, markAsUsed::close));
} }
public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, ActionListener<FetchSearchResult> listener) { public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, ActionListener<FetchSearchResult> listener) {
final ReaderContext readerContext = findReaderContext(request.contextId()); final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final Releasable markAsUsed = readerContext.markAsUsed();
runAsync(getExecutor(readerContext.indexShard()), () -> { runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
try (Releasable ignored = readerContext.markAsUsed(); try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request);
if (request.lastEmittedDoc() != null) { if (request.lastEmittedDoc() != null) {
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc(); searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc();
} }
@ -625,10 +613,10 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
} catch (Exception e) { } catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e); assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
logger.trace("Fetch phase failed", e); logger.trace("Fetch phase failed", e);
processFailure(shardSearchRequest, readerContext, e); processFailure(readerContext, e);
throw e; throw e;
} }
}, listener); }, wrapFailureListener(listener, readerContext, markAsUsed));
} }
private ReaderContext getReaderContext(ShardSearchContextId id) { private ReaderContext getReaderContext(ShardSearchContextId id) {
@ -642,19 +630,24 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
return null; return null;
} }
private ReaderContext findReaderContext(ShardSearchContextId id) throws SearchContextMissingException { private ReaderContext findReaderContext(ShardSearchContextId id, TransportRequest request) throws SearchContextMissingException {
final ReaderContext reader = getReaderContext(id); final ReaderContext reader = getReaderContext(id);
if (reader == null) { if (reader == null) {
throw new SearchContextMissingException(id); throw new SearchContextMissingException(id);
} }
try {
reader.validate(request);
} catch (Exception exc) {
processFailure(reader, exc);
throw exc;
}
return reader; return reader;
} }
final ReaderContext createOrGetReaderContext(ShardSearchRequest request, boolean keepStatesInContext) { final ReaderContext createOrGetReaderContext(ShardSearchRequest request, boolean keepStatesInContext) {
if (request.readerId() != null) { if (request.readerId() != null) {
assert keepStatesInContext == false; assert keepStatesInContext == false;
final ReaderContext readerContext = findReaderContext(request.readerId()); final ReaderContext readerContext = findReaderContext(request.readerId(), request);
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request);
final long keepAlive = request.keepAlive().millis(); final long keepAlive = request.keepAlive().millis();
checkKeepAliveLimit(keepAlive); checkKeepAliveLimit(keepAlive);
readerContext.keepAlive(keepAlive); readerContext.keepAlive(keepAlive);
@ -860,17 +853,38 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
} }
} }
private void processFailure(ShardSearchRequest request, ReaderContext context, Exception e) { private <T> ActionListener<T> wrapFailureListener(ActionListener<T> listener, ReaderContext context, Releasable releasable) {
if (context.singleSession() || request.scroll() != null) { return new ActionListener<T>() {
@Override
public void onResponse(T resp) {
Releasables.close(releasable);
listener.onResponse(resp);
}
@Override
public void onFailure(Exception exc) {
processFailure(context, exc);
Releasables.close(releasable);
listener.onFailure(exc);
}
};
}
private boolean isScrollContext(ReaderContext context) {
return context instanceof LegacyReaderContext && context.singleSession() == false;
}
private void processFailure(ReaderContext context, Exception exc) {
if (context.singleSession() || isScrollContext(context)) {
// we release the reader on failure if the request is a normal search or a scroll // we release the reader on failure if the request is a normal search or a scroll
freeReaderContext(context.id()); freeReaderContext(context.id());
} }
try { try {
if (Lucene.isCorruptionException(e)) { if (Lucene.isCorruptionException(exc)) {
context.indexShard().failShard("search execution corruption failure", e); context.indexShard().failShard("search execution corruption failure", exc);
} }
} catch (Exception inner) { } catch (Exception inner) {
inner.addSuppressed(e); inner.addSuppressed(exc);
logger.warn("failed to process shard failure to (potentially) send back shard failure on corruption", inner); logger.warn("failed to process shard failure to (potentially) send back shard failure on corruption", inner);
} }
} }
@ -1145,42 +1159,42 @@ public class SearchService extends AbstractLifecycleComponent implements IndexEv
*/ */
public CanMatchResponse canMatch(ShardSearchRequest request) throws IOException { public CanMatchResponse canMatch(ShardSearchRequest request) throws IOException {
assert request.searchType() == SearchType.QUERY_THEN_FETCH : "unexpected search type: " + request.searchType(); assert request.searchType() == SearchType.QUERY_THEN_FETCH : "unexpected search type: " + request.searchType();
final ReaderContext readerContext = request.readerId() != null ? getReaderContext(request.readerId()) : null; final ReaderContext readerContext = request.readerId() != null ? findReaderContext(request.readerId(), request) : null;
final Releasable markAsUsed = readerContext != null ? readerContext.markAsUsed() : null; try (Releasable ignored = readerContext != null ? readerContext.markAsUsed() : () -> {}) {
final IndexService indexService; final IndexService indexService;
final Engine.Searcher canMatchSearcher; final Engine.Searcher canMatchSearcher;
final boolean hasRefreshPending; final boolean hasRefreshPending;
if (readerContext != null) { if (readerContext != null) {
readerContext.indexShard().getSearchOperationListener().validateSearchContext(readerContext, request); checkKeepAliveLimit(request.keepAlive().millis());
checkKeepAliveLimit(request.keepAlive().millis()); readerContext.keepAlive(request.keepAlive().millis());
readerContext.keepAlive(request.keepAlive().millis()); indexService = readerContext.indexService();
indexService = readerContext.indexService(); canMatchSearcher = readerContext.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE);
canMatchSearcher = readerContext.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE); hasRefreshPending = false;
hasRefreshPending = false;
} else {
indexService = indicesService.indexServiceSafe(request.shardId().getIndex());
IndexShard indexShard = indexService.getShard(request.shardId().getId());
hasRefreshPending = indexShard.hasRefreshPending();
canMatchSearcher = indexShard.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE);
}
try (Releasable ignored = Releasables.wrap(markAsUsed, canMatchSearcher)) {
QueryShardContext context = indexService.newQueryShardContext(request.shardId().id(), canMatchSearcher,
request::nowInMillis, request.getClusterAlias());
Rewriteable.rewrite(request.getRewriteable(), context, false);
final boolean aliasFilterCanMatch = request.getAliasFilter()
.getQueryBuilder() instanceof MatchNoneQueryBuilder == false;
FieldSortBuilder sortBuilder = FieldSortBuilder.getPrimaryFieldSortOrNull(request.source());
MinAndMax<?> minMax = sortBuilder != null ? FieldSortBuilder.getMinMaxOrNull(context, sortBuilder) : null;
final boolean canMatch;
if (canRewriteToMatchNone(request.source())) {
QueryBuilder queryBuilder = request.source().query();
canMatch = aliasFilterCanMatch && queryBuilder instanceof MatchNoneQueryBuilder == false;
} else { } else {
// null query means match_all indexService = indicesService.indexServiceSafe(request.shardId().getIndex());
canMatch = aliasFilterCanMatch; IndexShard indexShard = indexService.getShard(request.shardId().getId());
hasRefreshPending = indexShard.hasRefreshPending();
canMatchSearcher = indexShard.acquireSearcher(Engine.CAN_MATCH_SEARCH_SOURCE);
}
try (Releasable ignored2 = canMatchSearcher) {
QueryShardContext context = indexService.newQueryShardContext(request.shardId().id(), canMatchSearcher,
request::nowInMillis, request.getClusterAlias());
Rewriteable.rewrite(request.getRewriteable(), context, false);
final boolean aliasFilterCanMatch = request.getAliasFilter()
.getQueryBuilder() instanceof MatchNoneQueryBuilder == false;
FieldSortBuilder sortBuilder = FieldSortBuilder.getPrimaryFieldSortOrNull(request.source());
MinAndMax<?> minMax = sortBuilder != null ? FieldSortBuilder.getMinMaxOrNull(context, sortBuilder) : null;
final boolean canMatch;
if (canRewriteToMatchNone(request.source())) {
QueryBuilder queryBuilder = request.source().query();
canMatch = aliasFilterCanMatch && queryBuilder instanceof MatchNoneQueryBuilder == false;
} else {
// null query means match_all
canMatch = aliasFilterCanMatch;
}
return new CanMatchResponse(canMatch || hasRefreshPending, minMax);
} }
return new CanMatchResponse(canMatch || hasRefreshPending, minMax);
} }
} }

View File

@ -19,8 +19,6 @@
package org.elasticsearch.search.internal; package org.elasticsearch.search.internal;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.index.IndexService; import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.engine.Engine; import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShard;
@ -35,8 +33,7 @@ public class LegacyReaderContext extends ReaderContext {
private AggregatedDfs aggregatedDfs; private AggregatedDfs aggregatedDfs;
private RescoreDocIds rescoreDocIds; private RescoreDocIds rescoreDocIds;
private Engine.Searcher searcher; private volatile Engine.Searcher searcher;
private Releasable onClose;
public LegacyReaderContext(long id, IndexService indexService, IndexShard indexShard, Engine.SearcherSupplier reader, public LegacyReaderContext(long id, IndexService indexService, IndexShard indexShard, Engine.SearcherSupplier reader,
ShardSearchRequest shardSearchRequest, long keepAliveInMillis) { ShardSearchRequest shardSearchRequest, long keepAliveInMillis) {
@ -59,8 +56,9 @@ public class LegacyReaderContext extends ReaderContext {
// This ensures that we wrap the searcher's reader with the user's permissions // This ensures that we wrap the searcher's reader with the user's permissions
// when they are available. // when they are available.
if (searcher == null) { if (searcher == null) {
Engine.Searcher delegate = searcherSupplier.acquireSearcher(source); final Engine.Searcher delegate = searcherSupplier.acquireSearcher(source);
onClose = delegate::close; addOnClose(delegate);
// wrap the searcher so that closing is a noop, the actual closing happens when this context is closed
searcher = new Engine.Searcher(delegate.source(), delegate.getDirectoryReader(), searcher = new Engine.Searcher(delegate.source(), delegate.getDirectoryReader(),
delegate.getSimilarity(), delegate.getQueryCache(), delegate.getQueryCachingPolicy(), () -> {}); delegate.getSimilarity(), delegate.getQueryCache(), delegate.getQueryCachingPolicy(), () -> {});
} }
@ -69,12 +67,6 @@ public class LegacyReaderContext extends ReaderContext {
return super.acquireSearcher(source); return super.acquireSearcher(source);
} }
@Override
void doClose() {
Releasables.close(onClose, super::doClose);
}
@Override @Override
public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) { public ShardSearchRequest getShardSearchRequest(ShardSearchRequest other) {
return shardSearchRequest; return shardSearchRequest;

View File

@ -28,6 +28,7 @@ import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.search.RescoreDocIds; import org.elasticsearch.search.RescoreDocIds;
import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.transport.TransportRequest;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@ -84,6 +85,10 @@ public class ReaderContext implements Releasable {
}; };
} }
public void validate(TransportRequest request) {
indexShard.getSearchOperationListener().validateReaderContext(this, request);
}
private long nowInMillis() { private long nowInMillis() {
return indexShard.getThreadPool().relativeTimeInMillis(); return indexShard.getThreadPool().relativeTimeInMillis();
} }

View File

@ -116,7 +116,7 @@ public class SearchOperationListenerTests extends ESTestCase {
} }
@Override @Override
public void validateSearchContext(ReaderContext readerContext, TransportRequest request) { public void validateReaderContext(ReaderContext readerContext, TransportRequest request) {
assertNotNull(readerContext); assertNotNull(readerContext);
validateSearchContext.incrementAndGet(); validateSearchContext.incrementAndGet();
} }
@ -271,10 +271,10 @@ public class SearchOperationListenerTests extends ESTestCase {
assertEquals(0, validateSearchContext.get()); assertEquals(0, validateSearchContext.get());
if (throwingListeners == 0) { if (throwingListeners == 0) {
compositeListener.validateSearchContext(mock(ReaderContext.class), Empty.INSTANCE); compositeListener.validateReaderContext(mock(ReaderContext.class), Empty.INSTANCE);
} else { } else {
RuntimeException expected = expectThrows(RuntimeException.class, RuntimeException expected = expectThrows(RuntimeException.class,
() -> compositeListener.validateSearchContext(mock(ReaderContext.class), Empty.INSTANCE)); () -> compositeListener.validateReaderContext(mock(ReaderContext.class), Empty.INSTANCE));
assertNull(expected.getMessage()); assertNull(expected.getMessage());
assertEquals(throwingListeners - 1, expected.getSuppressed().length); assertEquals(throwingListeners - 1, expected.getSuppressed().length);
if (throwingListeners > 1) { if (throwingListeners > 1) {

View File

@ -79,7 +79,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
SecuritySearchOperationListener listener = SecuritySearchOperationListener listener =
new SecuritySearchOperationListener(securityContext, licenseState, auditTrailService); new SecuritySearchOperationListener(securityContext, licenseState, auditTrailService);
listener.onNewScrollContext(readerContext); listener.onNewScrollContext(readerContext);
listener.validateSearchContext(readerContext, Empty.INSTANCE); listener.validateReaderContext(readerContext, Empty.INSTANCE);
verify(licenseState, times(2)).isSecurityEnabled(); verify(licenseState, times(2)).isSecurityEnabled();
verifyZeroInteractions(auditTrailService, searchContext); verifyZeroInteractions(auditTrailService, searchContext);
} }
@ -136,7 +136,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
try (StoredContext ignore = threadContext.newStoredContext(false)) { try (StoredContext ignore = threadContext.newStoredContext(false)) {
Authentication authentication = new Authentication(new User("test", "role"), new RealmRef("realm", "file", "node"), null); Authentication authentication = new Authentication(new User("test", "role"), new RealmRef("realm", "file", "node"), null);
authentication.writeToContext(threadContext); authentication.writeToContext(threadContext);
listener.validateSearchContext(readerContext, Empty.INSTANCE); listener.validateReaderContext(readerContext, Empty.INSTANCE);
assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl)); assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl));
verify(licenseState).isSecurityEnabled(); verify(licenseState).isSecurityEnabled();
verifyZeroInteractions(auditTrail); verifyZeroInteractions(auditTrail);
@ -148,7 +148,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
Authentication authentication = Authentication authentication =
new Authentication(new User("test", "role"), new RealmRef(realmName, "file", nodeName), null); new Authentication(new User("test", "role"), new RealmRef(realmName, "file", nodeName), null);
authentication.writeToContext(threadContext); authentication.writeToContext(threadContext);
listener.validateSearchContext(readerContext, Empty.INSTANCE); listener.validateReaderContext(readerContext, Empty.INSTANCE);
assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl)); assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl));
verify(licenseState, times(2)).isSecurityEnabled(); verify(licenseState, times(2)).isSecurityEnabled();
verifyZeroInteractions(auditTrail); verifyZeroInteractions(auditTrail);
@ -166,7 +166,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
(AuthorizationInfo) () -> Collections.singletonMap(PRINCIPAL_ROLES_FIELD_NAME, authentication.getUser().roles())); (AuthorizationInfo) () -> Collections.singletonMap(PRINCIPAL_ROLES_FIELD_NAME, authentication.getUser().roles()));
final InternalScrollSearchRequest request = new InternalScrollSearchRequest(); final InternalScrollSearchRequest request = new InternalScrollSearchRequest();
SearchContextMissingException expected = expectThrows(SearchContextMissingException.class, SearchContextMissingException expected = expectThrows(SearchContextMissingException.class,
() -> listener.validateSearchContext(readerContext, request)); () -> listener.validateReaderContext(readerContext, request));
assertEquals(readerContext.id(), expected.contextId()); assertEquals(readerContext.id(), expected.contextId());
assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), nullValue()); assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), nullValue());
verify(licenseState, Mockito.atLeast(3)).isSecurityEnabled(); verify(licenseState, Mockito.atLeast(3)).isSecurityEnabled();
@ -185,7 +185,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
authentication.writeToContext(threadContext); authentication.writeToContext(threadContext);
threadContext.putTransient(ORIGINATING_ACTION_KEY, "action"); threadContext.putTransient(ORIGINATING_ACTION_KEY, "action");
final InternalScrollSearchRequest request = new InternalScrollSearchRequest(); final InternalScrollSearchRequest request = new InternalScrollSearchRequest();
listener.validateSearchContext(readerContext, request); listener.validateReaderContext(readerContext, request);
assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl)); assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), is(indicesAccessControl));
verify(licenseState, Mockito.atLeast(4)).isSecurityEnabled(); verify(licenseState, Mockito.atLeast(4)).isSecurityEnabled();
verifyNoMoreInteractions(auditTrail); verifyNoMoreInteractions(auditTrail);
@ -204,7 +204,7 @@ public class SecuritySearchOperationListenerTests extends ESSingleNodeTestCase {
(AuthorizationInfo) () -> Collections.singletonMap(PRINCIPAL_ROLES_FIELD_NAME, authentication.getUser().roles())); (AuthorizationInfo) () -> Collections.singletonMap(PRINCIPAL_ROLES_FIELD_NAME, authentication.getUser().roles()));
final InternalScrollSearchRequest request = new InternalScrollSearchRequest(); final InternalScrollSearchRequest request = new InternalScrollSearchRequest();
SearchContextMissingException expected = expectThrows(SearchContextMissingException.class, SearchContextMissingException expected = expectThrows(SearchContextMissingException.class,
() -> listener.validateSearchContext(readerContext, request)); () -> listener.validateReaderContext(readerContext, request));
assertEquals(readerContext.id(), expected.contextId()); assertEquals(readerContext.id(), expected.contextId());
assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), nullValue()); assertThat(threadContext.getTransient(AuthorizationServiceField.INDICES_PERMISSIONS_KEY), nullValue());
verify(licenseState, Mockito.atLeast(5)).isSecurityEnabled(); verify(licenseState, Mockito.atLeast(5)).isSecurityEnabled();

View File

@ -69,7 +69,7 @@ public final class SecuritySearchOperationListener implements SearchOperationLis
* object from the scroll context with the current authentication context * object from the scroll context with the current authentication context
*/ */
@Override @Override
public void validateSearchContext(ReaderContext readerContext, TransportRequest request) { public void validateReaderContext(ReaderContext readerContext, TransportRequest request) {
if (licenseState.isSecurityEnabled()) { if (licenseState.isSecurityEnabled()) {
if (readerContext.scrollContext() != null) { if (readerContext.scrollContext() != null) {
final Authentication originalAuth = readerContext.getFromContext(AuthenticationField.AUTHENTICATION_KEY); final Authentication originalAuth = readerContext.getFromContext(AuthenticationField.AUTHENTICATION_KEY);