inner_hits: Reuse inner hit query weight

Previously, a query weight was created for each search hit that needed to compute inner hits.
With this change, the weight of the inner hit query is computed once for all search hits.

Closes #23917
This commit is contained in:
Martijn van Groningen 2017-05-09 10:46:49 +02:00
parent e13db1b269
commit e5b42bed50
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
2 changed files with 184 additions and 99 deletions

View File

@ -19,19 +19,31 @@
package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DocValuesTermsQuery;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.mapper.DocumentMapper;
@ -40,11 +52,11 @@ import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.index.mapper.ParentFieldMapper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.SubSearchContext;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@ -57,7 +69,7 @@ public final class InnerHitsContext {
this.innerHits = new HashMap<>();
}
public InnerHitsContext(Map<String, BaseInnerHits> innerHits) {
InnerHitsContext(Map<String, BaseInnerHits> innerHits) {
this.innerHits = Objects.requireNonNull(innerHits);
}
@ -77,14 +89,16 @@ public final class InnerHitsContext {
public abstract static class BaseInnerHits extends SubSearchContext {
private final String name;
final SearchContext context;
private InnerHitsContext childInnerHits;
protected BaseInnerHits(String name, SearchContext context) {
BaseInnerHits(String name, SearchContext context) {
super(context);
this.name = name;
this.context = context;
}
public abstract TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException;
public abstract TopDocs[] topDocs(SearchHit[] hits) throws IOException;
public String getName() {
return name;
@ -98,6 +112,12 @@ public final class InnerHitsContext {
/**
 * Stores the nested inner hit definitions of this inner hit.
 *
 * @param childInnerHits inner hit definitions keyed by name; they are wrapped in a new
 *                       {@link InnerHitsContext} before being assigned to this instance
 */
public void setChildInnerHits(Map<String, InnerHitsContext.BaseInnerHits> childInnerHits) {
    final InnerHitsContext wrapped = new InnerHitsContext(childInnerHits);
    this.childInnerHits = wrapped;
}
/**
 * Creates the normalized weight for this inner hit's query using the searcher of the
 * enclosing search context.
 *
 * Scores are only requested when hits are going to be returned ({@code size() != 0}) and
 * either no sort is configured or the configured sort itself needs scores.
 *
 * @return the weight of {@code query()}
 * @throws IOException if the searcher fails to create the weight
 */
Weight createInnerHitQueryWeight() throws IOException {
    final boolean scoresRequired;
    if (size() == 0) {
        // Nothing is returned, so nothing needs to be scored
        scoresRequired = false;
    } else {
        scoresRequired = sort() == null || sort().sort.needsScores();
    }
    return context.searcher().createNormalizedWeight(query(), scoresRequired);
}
}
public static final class NestedInnerHits extends BaseInnerHits {
@ -112,35 +132,48 @@ public final class InnerHitsContext {
}
@Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException {
Query rawParentFilter;
if (parentObjectMapper == null) {
rawParentFilter = Queries.newNonNestedFilter();
} else {
rawParentFilter = parentObjectMapper.nestedTypeFilter();
}
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
Query childFilter = childObjectMapper.nestedTypeFilter();
int parentDocId = hitContext.readerContext().docBase + hitContext.docId();
Query q = Queries.filtered(query(), new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId));
if (size() == 0) {
return new TopDocs(context.searcher().count(q), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector<?> topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
Weight innerHitQueryWeight = createInnerHitQueryWeight();
TopDocs[] result = new TopDocs[hits.length];
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Query rawParentFilter;
if (parentObjectMapper == null) {
rawParentFilter = Queries.newNonNestedFilter();
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
rawParentFilter = parentObjectMapper.nestedTypeFilter();
}
try {
context.searcher().search(q, topDocsCollector);
} finally {
clearReleasables(Lifetime.COLLECTION);
int parentDocId = hit.docId();
final int readerIndex = ReaderUtil.subIndex(parentDocId, searcher().getIndexReader().leaves());
// With nested inner hits the nested docs are always in the same segment, so no need to use the other segments
LeafReaderContext ctx = searcher().getIndexReader().leaves().get(readerIndex);
Query childFilter = childObjectMapper.nestedTypeFilter();
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
Query q = new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId);
Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) {
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector<?> topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
}
try {
intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
} finally {
clearReleasables(Lifetime.COLLECTION);
}
result[i] = topDocsCollector.topDocs(from(), size());
}
return topDocsCollector.topDocs(from(), size());
}
return result;
}
}
@ -156,53 +189,65 @@ public final class InnerHitsContext {
}
@Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException {
final Query hitQuery;
if (isParentHit(hitContext.hit())) {
String field = ParentFieldMapper.joinField(hitContext.hit().getType());
hitQuery = new DocValuesTermsQuery(field, hitContext.hit().getId());
} else if (isChildHit(hitContext.hit())) {
DocumentMapper hitDocumentMapper = mapperService.documentMapper(hitContext.hit().getType());
final String parentType = hitDocumentMapper.parentFieldMapper().type();
SearchHitField parentField = hitContext.hit().field(ParentFieldMapper.NAME);
if (parentField == null) {
throw new IllegalStateException("All children must have a _parent");
}
Term uidTerm = context.mapperService().createUidTerm(parentType, parentField.getValue());
if (uidTerm == null) {
hitQuery = new MatchNoDocsQuery("Missing type: " + parentType);
public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
Weight innerHitQueryWeight = createInnerHitQueryWeight();
TopDocs[] result = new TopDocs[hits.length];
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
final Query hitQuery;
if (isParentHit(hit)) {
String field = ParentFieldMapper.joinField(hit.getType());
hitQuery = new DocValuesTermsQuery(field, hit.getId());
} else if (isChildHit(hit)) {
DocumentMapper hitDocumentMapper = mapperService.documentMapper(hit.getType());
final String parentType = hitDocumentMapper.parentFieldMapper().type();
SearchHitField parentField = hit.field(ParentFieldMapper.NAME);
if (parentField == null) {
throw new IllegalStateException("All children must have a _parent");
}
Term uidTerm = context.mapperService().createUidTerm(parentType, parentField.getValue());
if (uidTerm == null) {
hitQuery = new MatchNoDocsQuery("Missing type: " + parentType);
} else {
hitQuery = new TermQuery(uidTerm);
}
} else {
hitQuery = new TermQuery(uidTerm);
result[i] = Lucene.EMPTY_TOP_DOCS;
continue;
}
} else {
return Lucene.EMPTY_TOP_DOCS;
}
BooleanQuery q = new BooleanQuery.Builder()
.add(query(), Occur.MUST)
// Only include docs that have the current hit as parent
.add(hitQuery, Occur.FILTER)
// Only include docs that have this inner hits type
.add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER)
.build();
if (size() == 0) {
final int count = context.searcher().count(q);
return new TopDocs(count, Lucene.EMPTY_SCORE_DOCS, 0);
} else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
BooleanQuery q = new BooleanQuery.Builder()
// Only include docs that have the current hit as parent
.add(hitQuery, Occur.FILTER)
// Only include docs that have this inner hits type
.add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER)
.build();
Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) {
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
}
result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
}
try {
for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
}
} finally {
clearReleasables(Lifetime.COLLECTION);
}
result[i] = topDocsCollector.topDocs(from(), size());
}
try {
context.searcher().search(q, topDocsCollector);
} finally {
clearReleasables(Lifetime.COLLECTION);
}
return topDocsCollector.topDocs(from(), size());
}
return result;
}
private boolean isParentHit(SearchHit hit) {
@ -214,4 +259,42 @@ public final class InnerHitsContext {
return documentMapper.type().equals(hitDocumentMapper.parentFieldMapper().type());
}
}
/**
 * Feeds {@code collector} with the live documents of a single segment that are matched by
 * both {@code weight} and {@code innerHitQueryWeight}.
 *
 * Returns without collecting anything when either weight has no scorer supplier for the
 * segment, or when the collector throws {@link CollectionTerminatedException} while its
 * leaf collector is being set up.
 *
 * @param weight              weight whose matches restrict the documents to collect
 * @param innerHitQueryWeight weight of the inner hit query; its scorer is the one handed to the collector
 * @param collector           receives every matching live document
 * @param ctx                 the segment to run the intersection on
 * @throws IOException if reading from the segment fails
 */
static void intersect(Weight weight, Weight innerHitQueryWeight, Collector collector, LeafReaderContext ctx) throws IOException {
    final ScorerSupplier filterSupplier = weight.scorerSupplier(ctx);
    if (filterSupplier == null) {
        return;
    }
    // use random access since this scorer will be consumed on a minority of documents
    final Scorer filterScorer = filterSupplier.get(true);

    final ScorerSupplier innerHitSupplier = innerHitQueryWeight.scorerSupplier(ctx);
    if (innerHitSupplier == null) {
        return;
    }
    // use random access since this scorer will be consumed on a minority of documents
    final Scorer innerHitScorer = innerHitSupplier.get(true);

    final LeafCollector segmentCollector;
    try {
        segmentCollector = collector.getLeafCollector(ctx);
        // Only the inner hit query scorer is exposed to the collector, because that is the
        // part of the query that actually produces the score
        segmentCollector.setScorer(innerHitScorer);
    } catch (CollectionTerminatedException e) {
        return;
    }

    try {
        final Bits liveDocs = ctx.reader().getLiveDocs();
        final DocIdSetIterator matchesBoth = ConjunctionDISI.intersectIterators(
            Arrays.asList(innerHitScorer.iterator(), filterScorer.iterator()));
        int doc = matchesBoth.nextDoc();
        while (doc < DocIdSetIterator.NO_MORE_DOCS) {
            // A null liveDocs is treated as "every document is acceptable"
            if (liveDocs == null || liveDocs.get(doc)) {
                segmentCollector.collect(doc);
            }
            doc = matchesBoth.nextDoc();
        }
    } catch (CollectionTerminatedException e) {
        // ignore and continue
    }
}
}

View File

@ -22,12 +22,11 @@ package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.fetch.FetchPhase;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
@ -43,39 +42,42 @@ public final class InnerHitsFetchSubPhase implements FetchSubPhase {
}
@Override
public void hitExecute(SearchContext context, HitContext hitContext) {
public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) {
return;
}
Map<String, SearchHits> results = new HashMap<>();
for (Map.Entry<String, InnerHitsContext.BaseInnerHits> entry : context.innerHits().getInnerHits().entrySet()) {
InnerHitsContext.BaseInnerHits innerHits = entry.getValue();
TopDocs topDocs;
try {
topDocs = innerHits.topDocs(context, hitContext);
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
innerHits.queryResult().topDocs(topDocs, innerHits.sort() == null ? null : innerHits.sort().formats);
int[] docIdsToLoad = new int[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
}
innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length);
fetchPhase.execute(innerHits);
FetchSearchResult fetchResult = innerHits.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits();
for (int i = 0; i < internalHits.length; i++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i];
SearchHit searchHitFields = internalHits[i];
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
searchHitFields.sortValues(fieldDoc.fields, innerHits.sort().formats);
TopDocs[] topDocs = innerHits.topDocs(hits);
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
TopDocs topDoc = topDocs[i];
Map<String, SearchHits> results = hit.getInnerHits();
if (results == null) {
hit.setInnerHits(results = new HashMap<>());
}
innerHits.queryResult().topDocs(topDoc, innerHits.sort() == null ? null : innerHits.sort().formats);
int[] docIdsToLoad = new int[topDoc.scoreDocs.length];
for (int j = 0; j < topDoc.scoreDocs.length; j++) {
docIdsToLoad[j] = topDoc.scoreDocs[j].doc;
}
innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length);
fetchPhase.execute(innerHits);
FetchSearchResult fetchResult = innerHits.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits();
for (int j = 0; j < internalHits.length; j++) {
ScoreDoc scoreDoc = topDoc.scoreDocs[j];
SearchHit searchHitFields = internalHits[j];
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
searchHitFields.sortValues(fieldDoc.fields, innerHits.sort().formats);
}
}
results.put(entry.getKey(), fetchResult.hits());
}
results.put(entry.getKey(), fetchResult.hits());
}
hitContext.hit().setInnerHits(results);
}
}