inner_hits: Reuse inner hit query weight

Previously, a query weight was created for each search hit that needed to compute inner hits;
 with this change, the weight of the inner hit query is computed once for all search hits.

Closes #23917
This commit is contained in:
Martijn van Groningen 2017-05-09 10:46:49 +02:00
parent e13db1b269
commit e5b42bed50
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
2 changed files with 184 additions and 99 deletions

View File

@ -19,19 +19,31 @@
package org.elasticsearch.search.fetch.subphase; package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DocValuesTermsQuery; import org.apache.lucene.search.DocValuesTermsQuery;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery; import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentMapper;
@ -40,11 +52,11 @@ import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.index.mapper.ParentFieldMapper; import org.elasticsearch.index.mapper.ParentFieldMapper;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField; import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.SubSearchContext; import org.elasticsearch.search.internal.SubSearchContext;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -57,7 +69,7 @@ public final class InnerHitsContext {
this.innerHits = new HashMap<>(); this.innerHits = new HashMap<>();
} }
public InnerHitsContext(Map<String, BaseInnerHits> innerHits) { InnerHitsContext(Map<String, BaseInnerHits> innerHits) {
this.innerHits = Objects.requireNonNull(innerHits); this.innerHits = Objects.requireNonNull(innerHits);
} }
@ -77,14 +89,16 @@ public final class InnerHitsContext {
public abstract static class BaseInnerHits extends SubSearchContext { public abstract static class BaseInnerHits extends SubSearchContext {
private final String name; private final String name;
final SearchContext context;
private InnerHitsContext childInnerHits; private InnerHitsContext childInnerHits;
protected BaseInnerHits(String name, SearchContext context) { BaseInnerHits(String name, SearchContext context) {
super(context); super(context);
this.name = name; this.name = name;
this.context = context;
} }
public abstract TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException; public abstract TopDocs[] topDocs(SearchHit[] hits) throws IOException;
public String getName() { public String getName() {
return name; return name;
@ -98,6 +112,12 @@ public final class InnerHitsContext {
public void setChildInnerHits(Map<String, InnerHitsContext.BaseInnerHits> childInnerHits) { public void setChildInnerHits(Map<String, InnerHitsContext.BaseInnerHits> childInnerHits) {
this.childInnerHits = new InnerHitsContext(childInnerHits); this.childInnerHits = new InnerHitsContext(childInnerHits);
} }
Weight createInnerHitQueryWeight() throws IOException {
final boolean needsScores = size() != 0 && (sort() == null || sort().sort.needsScores());
return context.searcher().createNormalizedWeight(query(), needsScores);
}
} }
public static final class NestedInnerHits extends BaseInnerHits { public static final class NestedInnerHits extends BaseInnerHits {
@ -112,35 +132,48 @@ public final class InnerHitsContext {
} }
@Override @Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException { public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
Query rawParentFilter; Weight innerHitQueryWeight = createInnerHitQueryWeight();
if (parentObjectMapper == null) { TopDocs[] result = new TopDocs[hits.length];
rawParentFilter = Queries.newNonNestedFilter(); for (int i = 0; i < hits.length; i++) {
} else { SearchHit hit = hits[i];
rawParentFilter = parentObjectMapper.nestedTypeFilter(); Query rawParentFilter;
} if (parentObjectMapper == null) {
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter); rawParentFilter = Queries.newNonNestedFilter();
Query childFilter = childObjectMapper.nestedTypeFilter();
int parentDocId = hitContext.readerContext().docBase + hitContext.docId();
Query q = Queries.filtered(query(), new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId));
if (size() == 0) {
return new TopDocs(context.searcher().count(q), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector<?> topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else { } else {
topDocsCollector = TopScoreDocCollector.create(topN); rawParentFilter = parentObjectMapper.nestedTypeFilter();
} }
try {
context.searcher().search(q, topDocsCollector); int parentDocId = hit.docId();
} finally { final int readerIndex = ReaderUtil.subIndex(parentDocId, searcher().getIndexReader().leaves());
clearReleasables(Lifetime.COLLECTION); // With nested inner hits the nested docs are always in the same segment, so no need to use the other segments
LeafReaderContext ctx = searcher().getIndexReader().leaves().get(readerIndex);
Query childFilter = childObjectMapper.nestedTypeFilter();
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
Query q = new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId);
Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) {
TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector<?> topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
}
try {
intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
} finally {
clearReleasables(Lifetime.COLLECTION);
}
result[i] = topDocsCollector.topDocs(from(), size());
} }
return topDocsCollector.topDocs(from(), size());
} }
return result;
} }
} }
@ -156,53 +189,65 @@ public final class InnerHitsContext {
} }
@Override @Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException { public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
final Query hitQuery; Weight innerHitQueryWeight = createInnerHitQueryWeight();
if (isParentHit(hitContext.hit())) { TopDocs[] result = new TopDocs[hits.length];
String field = ParentFieldMapper.joinField(hitContext.hit().getType()); for (int i = 0; i < hits.length; i++) {
hitQuery = new DocValuesTermsQuery(field, hitContext.hit().getId()); SearchHit hit = hits[i];
} else if (isChildHit(hitContext.hit())) { final Query hitQuery;
DocumentMapper hitDocumentMapper = mapperService.documentMapper(hitContext.hit().getType()); if (isParentHit(hit)) {
final String parentType = hitDocumentMapper.parentFieldMapper().type(); String field = ParentFieldMapper.joinField(hit.getType());
SearchHitField parentField = hitContext.hit().field(ParentFieldMapper.NAME); hitQuery = new DocValuesTermsQuery(field, hit.getId());
if (parentField == null) { } else if (isChildHit(hit)) {
throw new IllegalStateException("All children must have a _parent"); DocumentMapper hitDocumentMapper = mapperService.documentMapper(hit.getType());
} final String parentType = hitDocumentMapper.parentFieldMapper().type();
Term uidTerm = context.mapperService().createUidTerm(parentType, parentField.getValue()); SearchHitField parentField = hit.field(ParentFieldMapper.NAME);
if (uidTerm == null) { if (parentField == null) {
hitQuery = new MatchNoDocsQuery("Missing type: " + parentType); throw new IllegalStateException("All children must have a _parent");
}
Term uidTerm = context.mapperService().createUidTerm(parentType, parentField.getValue());
if (uidTerm == null) {
hitQuery = new MatchNoDocsQuery("Missing type: " + parentType);
} else {
hitQuery = new TermQuery(uidTerm);
}
} else { } else {
hitQuery = new TermQuery(uidTerm); result[i] = Lucene.EMPTY_TOP_DOCS;
continue;
} }
} else {
return Lucene.EMPTY_TOP_DOCS;
}
BooleanQuery q = new BooleanQuery.Builder() BooleanQuery q = new BooleanQuery.Builder()
.add(query(), Occur.MUST) // Only include docs that have the current hit as parent
// Only include docs that have the current hit as parent .add(hitQuery, Occur.FILTER)
.add(hitQuery, Occur.FILTER) // Only include docs that have this inner hits type
// Only include docs that have this inner hits type .add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER)
.add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER) .build();
.build(); Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) { if (size() == 0) {
final int count = context.searcher().count(q); TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
return new TopDocs(count, Lucene.EMPTY_SCORE_DOCS, 0); for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
} else { intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc()); }
TopDocsCollector topDocsCollector; result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else { } else {
topDocsCollector = TopScoreDocCollector.create(topN); int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector topDocsCollector;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, true, trackScores(), trackScores());
} else {
topDocsCollector = TopScoreDocCollector.create(topN);
}
try {
for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
}
} finally {
clearReleasables(Lifetime.COLLECTION);
}
result[i] = topDocsCollector.topDocs(from(), size());
} }
try {
context.searcher().search(q, topDocsCollector);
} finally {
clearReleasables(Lifetime.COLLECTION);
}
return topDocsCollector.topDocs(from(), size());
} }
return result;
} }
private boolean isParentHit(SearchHit hit) { private boolean isParentHit(SearchHit hit) {
@ -214,4 +259,42 @@ public final class InnerHitsContext {
return documentMapper.type().equals(hitDocumentMapper.parentFieldMapper().type()); return documentMapper.type().equals(hitDocumentMapper.parentFieldMapper().type());
} }
} }
static void intersect(Weight weight, Weight innerHitQueryWeight, Collector collector, LeafReaderContext ctx) throws IOException {
ScorerSupplier scorerSupplier = weight.scorerSupplier(ctx);
if (scorerSupplier == null) {
return;
}
// use random access since this scorer will be consumed on a minority of documents
Scorer scorer = scorerSupplier.get(true);
ScorerSupplier innerHitQueryScorerSupplier = innerHitQueryWeight.scorerSupplier(ctx);
if (innerHitQueryScorerSupplier == null) {
return;
}
// use random access since this scorer will be consumed on a minority of documents
Scorer innerHitQueryScorer = innerHitQueryScorerSupplier.get(true);
final LeafCollector leafCollector;
try {
leafCollector = collector.getLeafCollector(ctx);
// Just setting the innerHitQueryScorer is ok, because that is the actual scoring part of the query
leafCollector.setScorer(innerHitQueryScorer);
} catch (CollectionTerminatedException e) {
return;
}
try {
Bits acceptDocs = ctx.reader().getLiveDocs();
DocIdSetIterator iterator = ConjunctionDISI.intersectIterators(Arrays.asList(innerHitQueryScorer.iterator(),
scorer.iterator()));
for (int docId = iterator.nextDoc(); docId < DocIdSetIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(docId)) {
leafCollector.collect(docId);
}
}
} catch (CollectionTerminatedException e) {
// ignore and continue
}
}
} }

View File

@ -22,12 +22,11 @@ package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.fetch.FetchPhase; import org.elasticsearch.search.fetch.FetchPhase;
import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
@ -43,39 +42,42 @@ public final class InnerHitsFetchSubPhase implements FetchSubPhase {
} }
@Override @Override
public void hitExecute(SearchContext context, HitContext hitContext) { public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) { if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) {
return; return;
} }
Map<String, SearchHits> results = new HashMap<>();
for (Map.Entry<String, InnerHitsContext.BaseInnerHits> entry : context.innerHits().getInnerHits().entrySet()) { for (Map.Entry<String, InnerHitsContext.BaseInnerHits> entry : context.innerHits().getInnerHits().entrySet()) {
InnerHitsContext.BaseInnerHits innerHits = entry.getValue(); InnerHitsContext.BaseInnerHits innerHits = entry.getValue();
TopDocs topDocs; TopDocs[] topDocs = innerHits.topDocs(hits);
try { for (int i = 0; i < hits.length; i++) {
topDocs = innerHits.topDocs(context, hitContext); SearchHit hit = hits[i];
} catch (IOException e) { TopDocs topDoc = topDocs[i];
throw ExceptionsHelper.convertToElastic(e);
} Map<String, SearchHits> results = hit.getInnerHits();
innerHits.queryResult().topDocs(topDocs, innerHits.sort() == null ? null : innerHits.sort().formats); if (results == null) {
int[] docIdsToLoad = new int[topDocs.scoreDocs.length]; hit.setInnerHits(results = new HashMap<>());
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
}
innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length);
fetchPhase.execute(innerHits);
FetchSearchResult fetchResult = innerHits.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits();
for (int i = 0; i < internalHits.length; i++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i];
SearchHit searchHitFields = internalHits[i];
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
searchHitFields.sortValues(fieldDoc.fields, innerHits.sort().formats);
} }
innerHits.queryResult().topDocs(topDoc, innerHits.sort() == null ? null : innerHits.sort().formats);
int[] docIdsToLoad = new int[topDoc.scoreDocs.length];
for (int j = 0; j < topDoc.scoreDocs.length; j++) {
docIdsToLoad[j] = topDoc.scoreDocs[j].doc;
}
innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length);
fetchPhase.execute(innerHits);
FetchSearchResult fetchResult = innerHits.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits();
for (int j = 0; j < internalHits.length; j++) {
ScoreDoc scoreDoc = topDoc.scoreDocs[j];
SearchHit searchHitFields = internalHits[j];
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
searchHitFields.sortValues(fieldDoc.fields, innerHits.sort().formats);
}
}
results.put(entry.getKey(), fetchResult.hits());
} }
results.put(entry.getKey(), fetchResult.hits());
} }
hitContext.hit().setInnerHits(results);
} }
} }