inner_hits: Reuse inner hit query weight

Previously, a query weight was created for each search hit that needed to compute inner hits;
 with this change, the weight of the inner hit query is computed once for all search hits.

Closes #23917
This commit is contained in:
Martijn van Groningen 2017-05-09 10:46:49 +02:00
parent e13db1b269
commit e5b42bed50
No known key found for this signature in database
GPG Key ID: AB236F4FCF2AF12A
2 changed files with 184 additions and 99 deletions

View File

@ -19,19 +19,31 @@
package org.elasticsearch.search.fetch.subphase; package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DocValuesTermsQuery; import org.apache.lucene.search.DocValuesTermsQuery;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector; import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector; import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery; import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery;
import org.apache.lucene.util.Bits;
import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.lucene.search.Queries; import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentMapper;
@ -40,11 +52,11 @@ import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.index.mapper.ParentFieldMapper; import org.elasticsearch.index.mapper.ParentFieldMapper;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField; import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.SubSearchContext; import org.elasticsearch.search.internal.SubSearchContext;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@ -57,7 +69,7 @@ public final class InnerHitsContext {
this.innerHits = new HashMap<>(); this.innerHits = new HashMap<>();
} }
public InnerHitsContext(Map<String, BaseInnerHits> innerHits) { InnerHitsContext(Map<String, BaseInnerHits> innerHits) {
this.innerHits = Objects.requireNonNull(innerHits); this.innerHits = Objects.requireNonNull(innerHits);
} }
@ -77,14 +89,16 @@ public final class InnerHitsContext {
public abstract static class BaseInnerHits extends SubSearchContext { public abstract static class BaseInnerHits extends SubSearchContext {
private final String name; private final String name;
final SearchContext context;
private InnerHitsContext childInnerHits; private InnerHitsContext childInnerHits;
protected BaseInnerHits(String name, SearchContext context) { BaseInnerHits(String name, SearchContext context) {
super(context); super(context);
this.name = name; this.name = name;
this.context = context;
} }
public abstract TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException; public abstract TopDocs[] topDocs(SearchHit[] hits) throws IOException;
public String getName() { public String getName() {
return name; return name;
@ -98,6 +112,12 @@ public final class InnerHitsContext {
public void setChildInnerHits(Map<String, InnerHitsContext.BaseInnerHits> childInnerHits) { public void setChildInnerHits(Map<String, InnerHitsContext.BaseInnerHits> childInnerHits) {
this.childInnerHits = new InnerHitsContext(childInnerHits); this.childInnerHits = new InnerHitsContext(childInnerHits);
} }
Weight createInnerHitQueryWeight() throws IOException {
final boolean needsScores = size() != 0 && (sort() == null || sort().sort.needsScores());
return context.searcher().createNormalizedWeight(query(), needsScores);
}
} }
public static final class NestedInnerHits extends BaseInnerHits { public static final class NestedInnerHits extends BaseInnerHits {
@ -112,20 +132,31 @@ public final class InnerHitsContext {
} }
@Override @Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException { public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
Weight innerHitQueryWeight = createInnerHitQueryWeight();
TopDocs[] result = new TopDocs[hits.length];
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
Query rawParentFilter; Query rawParentFilter;
if (parentObjectMapper == null) { if (parentObjectMapper == null) {
rawParentFilter = Queries.newNonNestedFilter(); rawParentFilter = Queries.newNonNestedFilter();
} else { } else {
rawParentFilter = parentObjectMapper.nestedTypeFilter(); rawParentFilter = parentObjectMapper.nestedTypeFilter();
} }
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
Query childFilter = childObjectMapper.nestedTypeFilter();
int parentDocId = hitContext.readerContext().docBase + hitContext.docId();
Query q = Queries.filtered(query(), new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId));
int parentDocId = hit.docId();
final int readerIndex = ReaderUtil.subIndex(parentDocId, searcher().getIndexReader().leaves());
// With nested inner hits the nested docs are always in the same segment as the parent hit, so there is no need to search the other segments
LeafReaderContext ctx = searcher().getIndexReader().leaves().get(readerIndex);
Query childFilter = childObjectMapper.nestedTypeFilter();
BitSetProducer parentFilter = context.bitsetFilterCache().getBitSetProducer(rawParentFilter);
Query q = new ParentChildrenBlockJoinQuery(parentFilter, childFilter, parentDocId);
Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) { if (size() == 0) {
return new TopDocs(context.searcher().count(q), Lucene.EMPTY_SCORE_DOCS, 0); TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else { } else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc()); int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector<?> topDocsCollector; TopDocsCollector<?> topDocsCollector;
@ -135,13 +166,15 @@ public final class InnerHitsContext {
topDocsCollector = TopScoreDocCollector.create(topN); topDocsCollector = TopScoreDocCollector.create(topN);
} }
try { try {
context.searcher().search(q, topDocsCollector); intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
} finally { } finally {
clearReleasables(Lifetime.COLLECTION); clearReleasables(Lifetime.COLLECTION);
} }
return topDocsCollector.topDocs(from(), size()); result[i] = topDocsCollector.topDocs(from(), size());
} }
} }
return result;
}
} }
public static final class ParentChildInnerHits extends BaseInnerHits { public static final class ParentChildInnerHits extends BaseInnerHits {
@ -156,15 +189,19 @@ public final class InnerHitsContext {
} }
@Override @Override
public TopDocs topDocs(SearchContext context, FetchSubPhase.HitContext hitContext) throws IOException { public TopDocs[] topDocs(SearchHit[] hits) throws IOException {
Weight innerHitQueryWeight = createInnerHitQueryWeight();
TopDocs[] result = new TopDocs[hits.length];
for (int i = 0; i < hits.length; i++) {
SearchHit hit = hits[i];
final Query hitQuery; final Query hitQuery;
if (isParentHit(hitContext.hit())) { if (isParentHit(hit)) {
String field = ParentFieldMapper.joinField(hitContext.hit().getType()); String field = ParentFieldMapper.joinField(hit.getType());
hitQuery = new DocValuesTermsQuery(field, hitContext.hit().getId()); hitQuery = new DocValuesTermsQuery(field, hit.getId());
} else if (isChildHit(hitContext.hit())) { } else if (isChildHit(hit)) {
DocumentMapper hitDocumentMapper = mapperService.documentMapper(hitContext.hit().getType()); DocumentMapper hitDocumentMapper = mapperService.documentMapper(hit.getType());
final String parentType = hitDocumentMapper.parentFieldMapper().type(); final String parentType = hitDocumentMapper.parentFieldMapper().type();
SearchHitField parentField = hitContext.hit().field(ParentFieldMapper.NAME); SearchHitField parentField = hit.field(ParentFieldMapper.NAME);
if (parentField == null) { if (parentField == null) {
throw new IllegalStateException("All children must have a _parent"); throw new IllegalStateException("All children must have a _parent");
} }
@ -175,19 +212,23 @@ public final class InnerHitsContext {
hitQuery = new TermQuery(uidTerm); hitQuery = new TermQuery(uidTerm);
} }
} else { } else {
return Lucene.EMPTY_TOP_DOCS; result[i] = Lucene.EMPTY_TOP_DOCS;
continue;
} }
BooleanQuery q = new BooleanQuery.Builder() BooleanQuery q = new BooleanQuery.Builder()
.add(query(), Occur.MUST)
// Only include docs that have the current hit as parent // Only include docs that have the current hit as parent
.add(hitQuery, Occur.FILTER) .add(hitQuery, Occur.FILTER)
// Only include docs that have this inner hits type // Only include docs that have this inner hits type
.add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER) .add(documentMapper.typeFilter(context.getQueryShardContext()), Occur.FILTER)
.build(); .build();
Weight weight = context.searcher().createNormalizedWeight(q, false);
if (size() == 0) { if (size() == 0) {
final int count = context.searcher().count(q); TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
return new TopDocs(count, Lucene.EMPTY_SCORE_DOCS, 0); for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
intersect(weight, innerHitQueryWeight, totalHitCountCollector, ctx);
}
result[i] = new TopDocs(totalHitCountCollector.getTotalHits(), Lucene.EMPTY_SCORE_DOCS, 0);
} else { } else {
int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc()); int topN = Math.min(from() + size(), context.searcher().getIndexReader().maxDoc());
TopDocsCollector topDocsCollector; TopDocsCollector topDocsCollector;
@ -197,13 +238,17 @@ public final class InnerHitsContext {
topDocsCollector = TopScoreDocCollector.create(topN); topDocsCollector = TopScoreDocCollector.create(topN);
} }
try { try {
context.searcher().search(q, topDocsCollector); for (LeafReaderContext ctx : context.searcher().getIndexReader().leaves()) {
intersect(weight, innerHitQueryWeight, topDocsCollector, ctx);
}
} finally { } finally {
clearReleasables(Lifetime.COLLECTION); clearReleasables(Lifetime.COLLECTION);
} }
return topDocsCollector.topDocs(from(), size()); result[i] = topDocsCollector.topDocs(from(), size());
} }
} }
return result;
}
private boolean isParentHit(SearchHit hit) { private boolean isParentHit(SearchHit hit) {
return hit.getType().equals(documentMapper.parentFieldMapper().type()); return hit.getType().equals(documentMapper.parentFieldMapper().type());
@ -214,4 +259,42 @@ public final class InnerHitsContext {
return documentMapper.type().equals(hitDocumentMapper.parentFieldMapper().type()); return documentMapper.type().equals(hitDocumentMapper.parentFieldMapper().type());
} }
} }
static void intersect(Weight weight, Weight innerHitQueryWeight, Collector collector, LeafReaderContext ctx) throws IOException {
ScorerSupplier scorerSupplier = weight.scorerSupplier(ctx);
if (scorerSupplier == null) {
return;
}
// use random access since this scorer will be consumed on a minority of documents
Scorer scorer = scorerSupplier.get(true);
ScorerSupplier innerHitQueryScorerSupplier = innerHitQueryWeight.scorerSupplier(ctx);
if (innerHitQueryScorerSupplier == null) {
return;
}
// use random access since this scorer will be consumed on a minority of documents
Scorer innerHitQueryScorer = innerHitQueryScorerSupplier.get(true);
final LeafCollector leafCollector;
try {
leafCollector = collector.getLeafCollector(ctx);
// Just setting the innerHitQueryScorer is ok, because that is the actual scoring part of the query
leafCollector.setScorer(innerHitQueryScorer);
} catch (CollectionTerminatedException e) {
return;
}
try {
Bits acceptDocs = ctx.reader().getLiveDocs();
DocIdSetIterator iterator = ConjunctionDISI.intersectIterators(Arrays.asList(innerHitQueryScorer.iterator(),
scorer.iterator()));
for (int docId = iterator.nextDoc(); docId < DocIdSetIterator.NO_MORE_DOCS; docId = iterator.nextDoc()) {
if (acceptDocs == null || acceptDocs.get(docId)) {
leafCollector.collect(docId);
}
}
} catch (CollectionTerminatedException e) {
// ignore and continue
}
}
} }

View File

@ -22,12 +22,11 @@ package org.elasticsearch.search.fetch.subphase;
import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.fetch.FetchPhase; import org.elasticsearch.search.fetch.FetchPhase;
import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.FetchSubPhase; import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException; import java.io.IOException;
@ -43,31 +42,34 @@ public final class InnerHitsFetchSubPhase implements FetchSubPhase {
} }
@Override @Override
public void hitExecute(SearchContext context, HitContext hitContext) { public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOException {
if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) { if ((context.innerHits() != null && context.innerHits().getInnerHits().size() > 0) == false) {
return; return;
} }
Map<String, SearchHits> results = new HashMap<>();
for (Map.Entry<String, InnerHitsContext.BaseInnerHits> entry : context.innerHits().getInnerHits().entrySet()) { for (Map.Entry<String, InnerHitsContext.BaseInnerHits> entry : context.innerHits().getInnerHits().entrySet()) {
InnerHitsContext.BaseInnerHits innerHits = entry.getValue(); InnerHitsContext.BaseInnerHits innerHits = entry.getValue();
TopDocs topDocs; TopDocs[] topDocs = innerHits.topDocs(hits);
try { for (int i = 0; i < hits.length; i++) {
topDocs = innerHits.topDocs(context, hitContext); SearchHit hit = hits[i];
} catch (IOException e) { TopDocs topDoc = topDocs[i];
throw ExceptionsHelper.convertToElastic(e);
Map<String, SearchHits> results = hit.getInnerHits();
if (results == null) {
hit.setInnerHits(results = new HashMap<>());
} }
innerHits.queryResult().topDocs(topDocs, innerHits.sort() == null ? null : innerHits.sort().formats); innerHits.queryResult().topDocs(topDoc, innerHits.sort() == null ? null : innerHits.sort().formats);
int[] docIdsToLoad = new int[topDocs.scoreDocs.length]; int[] docIdsToLoad = new int[topDoc.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) { for (int j = 0; j < topDoc.scoreDocs.length; j++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc; docIdsToLoad[j] = topDoc.scoreDocs[j].doc;
} }
innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length); innerHits.docIdsToLoad(docIdsToLoad, 0, docIdsToLoad.length);
fetchPhase.execute(innerHits); fetchPhase.execute(innerHits);
FetchSearchResult fetchResult = innerHits.fetchResult(); FetchSearchResult fetchResult = innerHits.fetchResult();
SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits(); SearchHit[] internalHits = fetchResult.fetchResult().hits().internalHits();
for (int i = 0; i < internalHits.length; i++) { for (int j = 0; j < internalHits.length; j++) {
ScoreDoc scoreDoc = topDocs.scoreDocs[i]; ScoreDoc scoreDoc = topDoc.scoreDocs[j];
SearchHit searchHitFields = internalHits[i]; SearchHit searchHitFields = internalHits[j];
searchHitFields.score(scoreDoc.score); searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc) { if (scoreDoc instanceof FieldDoc) {
FieldDoc fieldDoc = (FieldDoc) scoreDoc; FieldDoc fieldDoc = (FieldDoc) scoreDoc;
@ -76,6 +78,6 @@ public final class InnerHitsFetchSubPhase implements FetchSubPhase {
} }
results.put(entry.getKey(), fetchResult.hits()); results.put(entry.getKey(), fetchResult.hits());
} }
hitContext.hit().setInnerHits(results); }
} }
} }