lucene 4: Upgraded o.e.search.dfs package. #2

This commit is contained in:
Martijn van Groningen 2012-10-30 22:36:43 +01:00 committed by Shay Banon
parent 5f942ef63d
commit 083df0a86c
5 changed files with 109 additions and 32 deletions

View File

@ -88,21 +88,36 @@ public class SearchPhaseController extends AbstractComponent {
}
public AggregatedDfs aggregateDfs(Iterable<DfsSearchResult> results) {
TMap<Term, TermStatistics> dfMap = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
TMap<Term, TermStatistics> termStatistics = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
long aggMaxDoc = 0;
for (DfsSearchResult result : results) {
for (int i = 0; i < result.termStatistics().length; i++) {
TermStatistics existing = dfMap.get(result.terms()[i]);
TermStatistics existing = termStatistics.get(result.terms()[i]);
if (existing != null) {
dfMap.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
termStatistics.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
} else {
dfMap.put(result.terms()[i], result.termStatistics()[i]);
termStatistics.put(result.terms()[i], result.termStatistics()[i]);
}
}
for (Map.Entry<String, CollectionStatistics> entry : result.fieldStatistics().entrySet()) {
CollectionStatistics existing = fieldStatistics.get(entry.getKey());
if (existing != null) {
CollectionStatistics merged = new CollectionStatistics(
entry.getKey(), existing.maxDoc() + entry.getValue().maxDoc(),
existing.docCount() + entry.getValue().docCount(),
existing.sumTotalTermFreq() + entry.getValue().sumTotalTermFreq(),
existing.sumDocFreq() + entry.getValue().sumDocFreq()
);
fieldStatistics.put(entry.getKey(), merged);
} else {
fieldStatistics.put(entry.getKey(), entry.getValue());
}
}
aggMaxDoc += result.maxDoc();
}
return new AggregatedDfs(dfMap, aggMaxDoc);
return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
}
public ShardDoc[] sortDocs(Collection<? extends QuerySearchResultProvider> results1) {

View File

@ -20,16 +20,14 @@
package org.elasticsearch.search.dfs;
import gnu.trove.impl.Constants;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.TMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.trove.ExtTObjectIntHasMap;
import java.io.IOException;
import java.util.Map;
@ -39,21 +37,26 @@ import java.util.Map;
*/
public class AggregatedDfs implements Streamable {
private TMap<Term, TermStatistics> dfMap;
private TMap<Term, TermStatistics> termStatistics;
private TMap<String, CollectionStatistics> fieldStatistics;
private long maxDoc;
private AggregatedDfs() {
}
public AggregatedDfs(TMap<Term, TermStatistics> dfMap, long maxDoc) {
this.dfMap = dfMap;
public AggregatedDfs(TMap<Term, TermStatistics> termStatistics, TMap<String, CollectionStatistics> fieldStatistics, long maxDoc) {
this.termStatistics = termStatistics;
this.fieldStatistics = fieldStatistics;
this.maxDoc = maxDoc;
}
public TMap<Term, TermStatistics> dfMap() {
return dfMap;
public TMap<Term, TermStatistics> termStatistics() {
return termStatistics;
}
public TMap<String, CollectionStatistics> fieldStatistics() {
return fieldStatistics;
}
public long maxDoc() {
@ -69,20 +72,26 @@ public class AggregatedDfs implements Streamable {
@Override
public void readFrom(StreamInput in) throws IOException {
int size = in.readVInt();
dfMap = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
termStatistics = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
for (int i = 0; i < size; i++) {
Term term = new Term(in.readString(), in.readBytesRef());
TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), in.readVLong());
dfMap.put(term, stats);
termStatistics.put(term, stats);
}
size = in.readVInt();
fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
for (int i = 0; i < size; i++) {
String field = in.readString();
CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
fieldStatistics.put(field, stats);
}
maxDoc = in.readVLong();
}
@Override
public void writeTo(final StreamOutput out) throws IOException {
out.writeVInt(dfMap.size());
for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : dfMap.entrySet()) {
out.writeVInt(termStatistics.size());
for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : termStatistics.entrySet()) {
Term term = termTermStatisticsEntry.getKey();
out.writeString(term.field());
out.writeBytesRef(term.bytes());
@ -92,6 +101,15 @@ public class AggregatedDfs implements Streamable {
out.writeVLong(stats.totalTermFreq());
}
out.writeVInt(fieldStatistics.size());
for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
out.writeString(entry.getKey());
out.writeVLong(entry.getValue().maxDoc());
out.writeVLong(entry.getValue().docCount());
out.writeVLong(entry.getValue().sumTotalTermFreq());
out.writeVLong(entry.getValue().sumDocFreq());
}
out.writeVLong(maxDoc);
}
}

View File

@ -24,7 +24,6 @@ import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.ElasticSearchIllegalArgumentException;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
@ -34,25 +33,25 @@ import java.util.List;
*/
public class CachedDfSource extends IndexSearcher {
private final AggregatedDfs dfs;
private final AggregatedDfs aggregatedDfs;
private final int maxDoc;
public CachedDfSource(IndexReader reader, AggregatedDfs dfs, Similarity similarity) throws IOException {
public CachedDfSource(IndexReader reader, AggregatedDfs aggregatedDfs, Similarity similarity) throws IOException {
super(reader);
this.dfs = dfs;
this.aggregatedDfs = aggregatedDfs;
setSimilarity(similarity);
if (dfs.maxDoc() > Integer.MAX_VALUE) {
if (aggregatedDfs.maxDoc() > Integer.MAX_VALUE) {
maxDoc = Integer.MAX_VALUE;
} else {
maxDoc = (int) dfs.maxDoc();
maxDoc = (int) aggregatedDfs.maxDoc();
}
}
@Override
public TermStatistics termStatistics(Term term, TermContext context) throws IOException {
TermStatistics termStatistics = dfs.dfMap().get(term);
TermStatistics termStatistics = aggregatedDfs.termStatistics().get(term);
if (termStatistics == null) {
throw new ElasticSearchIllegalArgumentException("Not distributed term statistics for term: " + term);
}
@ -61,7 +60,11 @@ public class CachedDfSource extends IndexSearcher {
@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
throw new UnsupportedOperationException();
CollectionStatistics collectionStatistics = aggregatedDfs.fieldStatistics().get(field);
if (collectionStatistics == null) {
throw new ElasticSearchIllegalArgumentException("Not distributed collection statistics for field: " + field);
}
return collectionStatistics;
}
public int maxDoc() {

View File

@ -20,17 +20,23 @@
package org.elasticsearch.search.dfs;
import com.google.common.collect.ImmutableMap;
import gnu.trove.map.TMap;
import gnu.trove.set.hash.THashSet;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.util.concurrent.ThreadLocals;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.SearchPhase;
import org.elasticsearch.search.internal.SearchContext;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
*
@ -71,11 +77,16 @@ public class DfsPhase implements SearchPhase {
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
}
// TODO: LUCENE 4 UPGRADE - add collection stats for each unique field, for distributed scoring
// context.searcher().collectionStatistics()
TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
for (Term term : terms) {
if (!fieldStatistics.containsKey(term.field())) {
fieldStatistics.put(term.field(), context.searcher().collectionStatistics(term.field()));
}
}
context.dfsResult().termsAndFreqs(terms, termStatistics);
context.dfsResult().maxDoc(context.searcher().getIndexReader().maxDoc());
context.dfsResult().termsStatistics(terms, termStatistics)
.fieldStatistics(fieldStatistics)
.maxDoc(context.searcher().getIndexReader().maxDoc());
} catch (Exception e) {
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
}

View File

@ -19,16 +19,20 @@
package org.elasticsearch.search.dfs;
import gnu.trove.map.TMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.transport.TransportResponse;
import java.io.IOException;
import java.util.Map;
/**
*
@ -42,6 +46,7 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
private long id;
private Term[] terms;
private TermStatistics[] termStatistics;
private TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
private int maxDoc;
public DfsSearchResult() {
@ -75,12 +80,17 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
return maxDoc;
}
public DfsSearchResult termsAndFreqs(Term[] terms, TermStatistics[] termStatistics) {
public DfsSearchResult termsStatistics(Term[] terms, TermStatistics[] termStatistics) {
this.terms = terms;
this.termStatistics = termStatistics;
return this;
}
public DfsSearchResult fieldStatistics(TMap<String, CollectionStatistics> fieldStatistics) {
this.fieldStatistics = fieldStatistics;
return this;
}
public Term[] terms() {
return terms;
}
@ -89,6 +99,10 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
return termStatistics;
}
public TMap<String, CollectionStatistics> fieldStatistics() {
return fieldStatistics;
}
public static DfsSearchResult readDfsSearchResult(StreamInput in) throws IOException, ClassNotFoundException {
DfsSearchResult result = new DfsSearchResult();
result.readFrom(in);
@ -121,6 +135,13 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
termStatistics[i] = new TermStatistics(term, docFreq, totalTermFreq);
}
}
int numFieldStatistics = in.readVInt();
for (int i = 0; i < numFieldStatistics; i++) {
String field = in.readString();
CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
fieldStatistics.put(field, stats);
}
maxDoc = in.readVInt();
}
@ -139,6 +160,15 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
out.writeVLong(termStatistic.docFreq());
out.writeVLong(termStatistic.totalTermFreq());
}
out.writeVInt(fieldStatistics.size());
for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
out.writeString(entry.getKey());
out.writeVLong(entry.getValue().maxDoc());
out.writeVLong(entry.getValue().docCount());
out.writeVLong(entry.getValue().sumTotalTermFreq());
out.writeVLong(entry.getValue().sumDocFreq());
}
out.writeVInt(maxDoc);
}
}