From 083df0a86cfbf06bcf5a93e2f14a5e66e7e424a0 Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Tue, 30 Oct 2012 22:36:43 +0100 Subject: [PATCH] lucene 4: Upgraded o.e.search.dfs package. #2 --- .../controller/SearchPhaseController.java | 25 ++++++++-- .../search/dfs/AggregatedDfs.java | 46 +++++++++++++------ .../search/dfs/CachedDfSource.java | 19 ++++---- .../elasticsearch/search/dfs/DfsPhase.java | 19 ++++++-- .../search/dfs/DfsSearchResult.java | 32 ++++++++++++- 5 files changed, 109 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/elasticsearch/search/controller/SearchPhaseController.java b/src/main/java/org/elasticsearch/search/controller/SearchPhaseController.java index eb105c34b3a..b9e2ffabbc1 100644 --- a/src/main/java/org/elasticsearch/search/controller/SearchPhaseController.java +++ b/src/main/java/org/elasticsearch/search/controller/SearchPhaseController.java @@ -88,21 +88,36 @@ public class SearchPhaseController extends AbstractComponent { } public AggregatedDfs aggregateDfs(Iterable results) { - TMap dfMap = new ExtTHashMap(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR); + TMap termStatistics = new ExtTHashMap(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR); + TMap fieldStatistics = new ExtTHashMap(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR); long aggMaxDoc = 0; for (DfsSearchResult result : results) { for (int i = 0; i < result.termStatistics().length; i++) { - TermStatistics existing = dfMap.get(result.terms()[i]); + TermStatistics existing = termStatistics.get(result.terms()[i]); if (existing != null) { - dfMap.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq())); + termStatistics.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq())); } else { - dfMap.put(result.terms()[i], result.termStatistics()[i]); + termStatistics.put(result.terms()[i], result.termStatistics()[i]); } } + for (Map.Entry entry : result.fieldStatistics().entrySet()) { + CollectionStatistics existing = fieldStatistics.get(entry.getKey()); + if (existing != null) { + CollectionStatistics merged = new CollectionStatistics( + entry.getKey(), existing.maxDoc() + entry.getValue().maxDoc(), + existing.docCount() + entry.getValue().docCount(), + existing.sumTotalTermFreq() + entry.getValue().sumTotalTermFreq(), + existing.sumDocFreq() + entry.getValue().sumDocFreq() + ); + fieldStatistics.put(entry.getKey(), merged); + } else { + fieldStatistics.put(entry.getKey(), entry.getValue()); + } + } aggMaxDoc += result.maxDoc(); } - return new AggregatedDfs(dfMap, aggMaxDoc); + return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc); } public ShardDoc[] sortDocs(Collection results1) { diff --git a/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java b/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java index 1f0e8c348c5..6b750e153e6 100644 --- a/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java +++ b/src/main/java/org/elasticsearch/search/dfs/AggregatedDfs.java @@ -20,16 +20,14 @@ package org.elasticsearch.search.dfs; import gnu.trove.impl.Constants; -import gnu.trove.iterator.TObjectIntIterator; import gnu.trove.map.TMap; -import gnu.trove.map.hash.TObjectIntHashMap; import org.apache.lucene.index.Term; +import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.TermStatistics; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; import org.elasticsearch.common.trove.ExtTHashMap; -import org.elasticsearch.common.trove.ExtTObjectIntHasMap; import java.io.IOException; import java.util.Map; @@ -39,21 +37,26 @@ import java.util.Map; */ public class AggregatedDfs implements Streamable { - private TMap dfMap; - + private TMap termStatistics; + private TMap fieldStatistics; private long maxDoc; private AggregatedDfs() { } - public AggregatedDfs(TMap dfMap, long maxDoc) { - this.dfMap = dfMap; + public AggregatedDfs(TMap termStatistics, TMap fieldStatistics, long maxDoc) { + this.termStatistics = termStatistics; + this.fieldStatistics = fieldStatistics; this.maxDoc = maxDoc; } - public TMap dfMap() { - return dfMap; + public TMap termStatistics() { + return termStatistics; + } + + public TMap fieldStatistics() { + return fieldStatistics; } public long maxDoc() { @@ -69,20 +72,26 @@ public class AggregatedDfs implements Streamable { @Override public void readFrom(StreamInput in) throws IOException { int size = in.readVInt(); - dfMap = new ExtTHashMap(size, Constants.DEFAULT_LOAD_FACTOR); + termStatistics = new ExtTHashMap(size, Constants.DEFAULT_LOAD_FACTOR); for (int i = 0; i < size; i++) { Term term = new Term(in.readString(), in.readBytesRef()); TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), in.readVLong()); - dfMap.put(term, stats); + termStatistics.put(term, stats); + } + size = in.readVInt(); + fieldStatistics = new ExtTHashMap(size, Constants.DEFAULT_LOAD_FACTOR); + for (int i = 0; i < size; i++) { + String field = in.readString(); + CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong()); + fieldStatistics.put(field, stats); } maxDoc = in.readVLong(); } @Override public void writeTo(final StreamOutput out) throws IOException { - out.writeVInt(dfMap.size()); - - for (Map.Entry termTermStatisticsEntry : dfMap.entrySet()) { + out.writeVInt(termStatistics.size()); + for (Map.Entry termTermStatisticsEntry : termStatistics.entrySet()) { Term term = termTermStatisticsEntry.getKey(); out.writeString(term.field()); out.writeBytesRef(term.bytes()); @@ -92,6 +101,15 @@ public class AggregatedDfs implements Streamable { out.writeVLong(stats.totalTermFreq()); } + out.writeVInt(fieldStatistics.size()); + for (Map.Entry entry : fieldStatistics.entrySet()) { + out.writeString(entry.getKey()); + out.writeVLong(entry.getValue().maxDoc()); + out.writeVLong(entry.getValue().docCount()); + out.writeVLong(entry.getValue().sumTotalTermFreq()); + out.writeVLong(entry.getValue().sumDocFreq()); + } + out.writeVLong(maxDoc); } } diff --git a/src/main/java/org/elasticsearch/search/dfs/CachedDfSource.java b/src/main/java/org/elasticsearch/search/dfs/CachedDfSource.java index 1ecc2ac0eff..e28ff78268c 100644 --- a/src/main/java/org/elasticsearch/search/dfs/CachedDfSource.java +++ b/src/main/java/org/elasticsearch/search/dfs/CachedDfSource.java @@ -24,7 +24,6 @@ import org.apache.lucene.index.*; import org.apache.lucene.search.*; import org.apache.lucene.search.similarities.Similarity; import org.elasticsearch.ElasticSearchIllegalArgumentException; -import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; import java.util.List; @@ -34,25 +33,25 @@ import java.util.List; */ public class CachedDfSource extends IndexSearcher { - private final AggregatedDfs dfs; + private final AggregatedDfs aggregatedDfs; private final int maxDoc; - public CachedDfSource(IndexReader reader, AggregatedDfs dfs, Similarity similarity) throws IOException { + public CachedDfSource(IndexReader reader, AggregatedDfs aggregatedDfs, Similarity similarity) throws IOException { super(reader); - this.dfs = dfs; + this.aggregatedDfs = aggregatedDfs; setSimilarity(similarity); - if (dfs.maxDoc() > Integer.MAX_VALUE) { + if (aggregatedDfs.maxDoc() > Integer.MAX_VALUE) { maxDoc = Integer.MAX_VALUE; } else { - maxDoc = (int) dfs.maxDoc(); + maxDoc = (int) aggregatedDfs.maxDoc(); } } @Override public TermStatistics termStatistics(Term term, TermContext context) throws IOException { - TermStatistics termStatistics = dfs.dfMap().get(term); + TermStatistics termStatistics = aggregatedDfs.termStatistics().get(term); if (termStatistics == null) { throw new ElasticSearchIllegalArgumentException("Not distributed term statistics for term: " + term); } @@ -61,7 +60,11 @@ public class CachedDfSource extends IndexSearcher { @Override public CollectionStatistics collectionStatistics(String field) throws IOException { - throw new UnsupportedOperationException(); + CollectionStatistics collectionStatistics = aggregatedDfs.fieldStatistics().get(field); + if (collectionStatistics == null) { + throw new ElasticSearchIllegalArgumentException("Not distributed collection statistics for field: " + field); + } + return collectionStatistics; } public int maxDoc() { diff --git a/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java b/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java index f4394387060..7fd8d20f594 100644 --- a/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java +++ b/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java @@ -20,17 +20,23 @@ package org.elasticsearch.search.dfs; import com.google.common.collect.ImmutableMap; +import gnu.trove.map.TMap; import gnu.trove.set.hash.THashSet; import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermContext; +import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.TermStatistics; +import org.elasticsearch.common.trove.ExtTHashMap; import org.elasticsearch.common.util.concurrent.ThreadLocals; import org.elasticsearch.search.SearchParseElement; import org.elasticsearch.search.SearchPhase; import org.elasticsearch.search.internal.SearchContext; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; /** * @@ -71,11 +77,16 @@ public class DfsPhase implements SearchPhase { termStatistics[i] = context.searcher().termStatistics(terms[i], termContext); } - // TODO: LUCENE 4 UPGRADE - add collection stats for each unique field, for distributed scoring -// context.searcher().collectionStatistics() + TMap fieldStatistics = new ExtTHashMap(); + for (Term term : terms) { + if (!fieldStatistics.containsKey(term.field())) { + fieldStatistics.put(term.field(), context.searcher().collectionStatistics(term.field())); + } + } - context.dfsResult().termsAndFreqs(terms, termStatistics); - context.dfsResult().maxDoc(context.searcher().getIndexReader().maxDoc()); + context.dfsResult().termsStatistics(terms, termStatistics) + .fieldStatistics(fieldStatistics) + .maxDoc(context.searcher().getIndexReader().maxDoc()); } catch (Exception e) { throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e); } diff --git a/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java b/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java index 75de8e72065..680e99b2a6a 100644 --- a/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java +++ b/src/main/java/org/elasticsearch/search/dfs/DfsSearchResult.java @@ -19,16 +19,20 @@ package org.elasticsearch.search.dfs; +import gnu.trove.map.TMap; import org.apache.lucene.index.Term; +import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.TermStatistics; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.trove.ExtTHashMap; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.transport.TransportResponse; import java.io.IOException; +import java.util.Map; /** * @@ -42,6 +46,7 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes private long id; private Term[] terms; private TermStatistics[] termStatistics; + private TMap fieldStatistics = new ExtTHashMap(); private int maxDoc; public DfsSearchResult() { @@ -75,12 +80,17 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes return maxDoc; } - public DfsSearchResult termsAndFreqs(Term[] terms, TermStatistics[] termStatistics) { + public DfsSearchResult termsStatistics(Term[] terms, TermStatistics[] termStatistics) { this.terms = terms; this.termStatistics = termStatistics; return this; } + public DfsSearchResult fieldStatistics(TMap fieldStatistics) { + this.fieldStatistics = fieldStatistics; + return this; + } + public Term[] terms() { return terms; } @@ -89,6 +99,10 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes return termStatistics; } + public TMap fieldStatistics() { + return fieldStatistics; + } + public static DfsSearchResult readDfsSearchResult(StreamInput in) throws IOException, ClassNotFoundException { DfsSearchResult result = new DfsSearchResult(); result.readFrom(in); @@ -121,6 +135,13 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes termStatistics[i] = new TermStatistics(term, docFreq, totalTermFreq); } } + int numFieldStatistics = in.readVInt(); + for (int i = 0; i < numFieldStatistics; i++) { + String field = in.readString(); + CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong()); + fieldStatistics.put(field, stats); + } + maxDoc = in.readVInt(); } @@ -139,6 +160,15 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes out.writeVLong(termStatistic.docFreq()); out.writeVLong(termStatistic.totalTermFreq()); } + out.writeVInt(fieldStatistics.size()); + for (Map.Entry entry : fieldStatistics.entrySet()) { + out.writeString(entry.getKey()); + out.writeVLong(entry.getValue().maxDoc()); + out.writeVLong(entry.getValue().docCount()); + out.writeVLong(entry.getValue().sumTotalTermFreq()); + out.writeVLong(entry.getValue().sumDocFreq()); + } out.writeVInt(maxDoc); } + }