lucene 4: Upgraded o.e.search.dfs package. #2

Martijn van Groningen 2012-10-30 22:36:43 +01:00 committed by Shay Banon
parent 5f942ef63d
commit 083df0a86c
5 changed files with 109 additions and 32 deletions

SearchPhaseController.java

@@ -88,21 +88,36 @@ public class SearchPhaseController extends AbstractComponent {
     }
 
     public AggregatedDfs aggregateDfs(Iterable<DfsSearchResult> results) {
-        TMap<Term, TermStatistics> dfMap = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
+        TMap<Term, TermStatistics> termStatistics = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
+        TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
         long aggMaxDoc = 0;
         for (DfsSearchResult result : results) {
             for (int i = 0; i < result.termStatistics().length; i++) {
-                TermStatistics existing = dfMap.get(result.terms()[i]);
+                TermStatistics existing = termStatistics.get(result.terms()[i]);
                 if (existing != null) {
-                    dfMap.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
+                    termStatistics.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
                 } else {
-                    dfMap.put(result.terms()[i], result.termStatistics()[i]);
+                    termStatistics.put(result.terms()[i], result.termStatistics()[i]);
                 }
             }
+            for (Map.Entry<String, CollectionStatistics> entry : result.fieldStatistics().entrySet()) {
+                CollectionStatistics existing = fieldStatistics.get(entry.getKey());
+                if (existing != null) {
+                    CollectionStatistics merged = new CollectionStatistics(
+                            entry.getKey(), existing.maxDoc() + entry.getValue().maxDoc(),
+                            existing.docCount() + entry.getValue().docCount(),
+                            existing.sumTotalTermFreq() + entry.getValue().sumTotalTermFreq(),
+                            existing.sumDocFreq() + entry.getValue().sumDocFreq()
+                    );
+                    fieldStatistics.put(entry.getKey(), merged);
+                } else {
+                    fieldStatistics.put(entry.getKey(), entry.getValue());
+                }
+            }
             aggMaxDoc += result.maxDoc();
         }
-        return new AggregatedDfs(dfMap, aggMaxDoc);
+        return new AggregatedDfs(termStatistics, fieldStatistics, aggMaxDoc);
     }
 
     public ShardDoc[] sortDocs(Collection<? extends QuerySearchResultProvider> results1) {
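
The substantive change in aggregateDfs() is that per-shard CollectionStatistics are now merged across shards by summing every component, just as docFreq and totalTermFreq were already summed for terms. A standalone sketch of that merge rule, using only the Lucene 4.x CollectionStatistics API already visible in the diff (the class name and sample values are invented for illustration):

// Sketch only, not part of the commit: mirrors the merge in aggregateDfs().
import org.apache.lucene.search.CollectionStatistics;

public class FieldStatsMergeSketch {

    // Each shard reports statistics for its local index; the coordinating node
    // sums them so scoring later sees index-wide values.
    static CollectionStatistics merge(String field, CollectionStatistics a, CollectionStatistics b) {
        return new CollectionStatistics(field,
                a.maxDoc() + b.maxDoc(),
                a.docCount() + b.docCount(),
                a.sumTotalTermFreq() + b.sumTotalTermFreq(),
                a.sumDocFreq() + b.sumDocFreq());
    }

    public static void main(String[] args) {
        CollectionStatistics shard1 = new CollectionStatistics("body", 100, 90, 5000, 400);
        CollectionStatistics shard2 = new CollectionStatistics("body", 200, 180, 9000, 700);
        CollectionStatistics global = merge("body", shard1, shard2);
        System.out.println(global.maxDoc() + " " + global.docCount()); // prints: 300 270
    }
}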

AggregatedDfs.java

@@ -20,16 +20,14 @@
 package org.elasticsearch.search.dfs;
 
 import gnu.trove.impl.Constants;
-import gnu.trove.iterator.TObjectIntIterator;
 import gnu.trove.map.TMap;
-import gnu.trove.map.hash.TObjectIntHashMap;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.TermStatistics;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Streamable;
 import org.elasticsearch.common.trove.ExtTHashMap;
-import org.elasticsearch.common.trove.ExtTObjectIntHasMap;
 
 import java.io.IOException;
 import java.util.Map;

@@ -39,21 +37,26 @@ import java.util.Map;
  */
 public class AggregatedDfs implements Streamable {
 
-    private TMap<Term, TermStatistics> dfMap;
+    private TMap<Term, TermStatistics> termStatistics;
+    private TMap<String, CollectionStatistics> fieldStatistics;
     private long maxDoc;
 
     private AggregatedDfs() {
     }
 
-    public AggregatedDfs(TMap<Term, TermStatistics> dfMap, long maxDoc) {
-        this.dfMap = dfMap;
+    public AggregatedDfs(TMap<Term, TermStatistics> termStatistics, TMap<String, CollectionStatistics> fieldStatistics, long maxDoc) {
+        this.termStatistics = termStatistics;
+        this.fieldStatistics = fieldStatistics;
         this.maxDoc = maxDoc;
     }
 
-    public TMap<Term, TermStatistics> dfMap() {
-        return dfMap;
+    public TMap<Term, TermStatistics> termStatistics() {
+        return termStatistics;
+    }
+
+    public TMap<String, CollectionStatistics> fieldStatistics() {
+        return fieldStatistics;
     }
 
     public long maxDoc() {

@@ -69,20 +72,26 @@ public class AggregatedDfs implements Streamable {
     @Override
     public void readFrom(StreamInput in) throws IOException {
         int size = in.readVInt();
-        dfMap = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
+        termStatistics = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
         for (int i = 0; i < size; i++) {
             Term term = new Term(in.readString(), in.readBytesRef());
             TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), in.readVLong());
-            dfMap.put(term, stats);
+            termStatistics.put(term, stats);
+        }
+        size = in.readVInt();
+        fieldStatistics = new ExtTHashMap<String, CollectionStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
+        for (int i = 0; i < size; i++) {
+            String field = in.readString();
+            CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
+            fieldStatistics.put(field, stats);
         }
         maxDoc = in.readVLong();
     }
 
     @Override
     public void writeTo(final StreamOutput out) throws IOException {
-        out.writeVInt(dfMap.size());
-
-        for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : dfMap.entrySet()) {
+        out.writeVInt(termStatistics.size());
+        for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : termStatistics.entrySet()) {
             Term term = termTermStatisticsEntry.getKey();
             out.writeString(term.field());
             out.writeBytesRef(term.bytes());

@@ -92,6 +101,15 @@ public class AggregatedDfs implements Streamable {
             out.writeVLong(stats.totalTermFreq());
         }
 
+        out.writeVInt(fieldStatistics.size());
+        for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
+            out.writeString(entry.getKey());
+            out.writeVLong(entry.getValue().maxDoc());
+            out.writeVLong(entry.getValue().docCount());
+            out.writeVLong(entry.getValue().sumTotalTermFreq());
+            out.writeVLong(entry.getValue().sumDocFreq());
+        }
+
         out.writeVLong(maxDoc);
     }
 }
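
A note on readFrom()/writeTo(): the stream is not self-describing, so the reader must consume exactly the sequence the writer produced, a vInt count followed, per field, by the name and the four statistics in a fixed order. A self-contained illustration of that contract, using plain java.io streams as a stand-in for StreamInput/StreamOutput (the real code uses variable-length writeVInt/writeVLong; fixed-width primitives keep this sketch dependency-free):

// Sketch only: demonstrates the wire layout writeTo() and readFrom() agree on.
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

public class FieldStatsWireSketch {
    public static void main(String[] args) throws IOException {
        // order: maxDoc, docCount, sumTotalTermFreq, sumDocFreq (same as the diff)
        Map<String, long[]> fieldStats = new LinkedHashMap<String, long[]>();
        fieldStats.put("body", new long[]{300, 270, 14000, 1100});

        // write: count first, then name + four longs per field
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);
        out.writeInt(fieldStats.size());
        for (Map.Entry<String, long[]> e : fieldStats.entrySet()) {
            out.writeUTF(e.getKey());
            for (long v : e.getValue()) {
                out.writeLong(v);
            }
        }

        // read: consume the identical sequence, or the stream desynchronizes
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        int size = in.readInt();
        for (int i = 0; i < size; i++) {
            String field = in.readUTF();
            System.out.println(field + ": maxDoc=" + in.readLong() + " docCount=" + in.readLong()
                    + " sumTotalTermFreq=" + in.readLong() + " sumDocFreq=" + in.readLong());
        }
    }
}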

CachedDfSource.java

@@ -24,7 +24,6 @@ import org.apache.lucene.index.*;
 import org.apache.lucene.search.*;
 import org.apache.lucene.search.similarities.Similarity;
 import org.elasticsearch.ElasticSearchIllegalArgumentException;
-import org.elasticsearch.search.internal.SearchContext;
 
 import java.io.IOException;
 import java.util.List;

@@ -34,25 +33,25 @@ import java.util.List;
  */
 public class CachedDfSource extends IndexSearcher {
 
-    private final AggregatedDfs dfs;
+    private final AggregatedDfs aggregatedDfs;
 
     private final int maxDoc;
 
-    public CachedDfSource(IndexReader reader, AggregatedDfs dfs, Similarity similarity) throws IOException {
+    public CachedDfSource(IndexReader reader, AggregatedDfs aggregatedDfs, Similarity similarity) throws IOException {
         super(reader);
-        this.dfs = dfs;
+        this.aggregatedDfs = aggregatedDfs;
         setSimilarity(similarity);
-        if (dfs.maxDoc() > Integer.MAX_VALUE) {
+        if (aggregatedDfs.maxDoc() > Integer.MAX_VALUE) {
             maxDoc = Integer.MAX_VALUE;
         } else {
-            maxDoc = (int) dfs.maxDoc();
+            maxDoc = (int) aggregatedDfs.maxDoc();
         }
     }
 
     @Override
     public TermStatistics termStatistics(Term term, TermContext context) throws IOException {
-        TermStatistics termStatistics = dfs.dfMap().get(term);
+        TermStatistics termStatistics = aggregatedDfs.termStatistics().get(term);
         if (termStatistics == null) {
             throw new ElasticSearchIllegalArgumentException("Not distributed term statistics for term: " + term);
         }

@@ -61,7 +60,11 @@ public class CachedDfSource extends IndexSearcher {
     @Override
     public CollectionStatistics collectionStatistics(String field) throws IOException {
-        throw new UnsupportedOperationException();
+        CollectionStatistics collectionStatistics = aggregatedDfs.fieldStatistics().get(field);
+        if (collectionStatistics == null) {
+            throw new ElasticSearchIllegalArgumentException("Not distributed collection statistics for field: " + field);
+        }
+        return collectionStatistics;
     }
 
     public int maxDoc() {
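
With this change CachedDfSource can serve both statistics types from the aggregated dfs result during the query phase, and it fails loudly when an entry is missing instead of silently falling back to shard-local statistics, which would skew scores across shards. The lookup-or-throw pattern reduced to its core (a generic helper with invented names; the real class uses Trove maps and ElasticSearchIllegalArgumentException):

// Sketch only: the cache-or-fail lookup used by termStatistics() and
// collectionStatistics() above, with plain JDK types.
import java.util.HashMap;
import java.util.Map;

public class StatsLookupSketch {

    // A missing entry means the dfs phase never collected statistics for this
    // key; failing is safer than quietly scoring with local statistics.
    static <K, V> V requireStat(Map<K, V> cache, K key, String what) {
        V value = cache.get(key);
        if (value == null) {
            throw new IllegalArgumentException("No distributed " + what + " for: " + key);
        }
        return value;
    }

    public static void main(String[] args) {
        Map<String, Long> docFreqs = new HashMap<String, Long>();
        docFreqs.put("body:foo", 42L);
        System.out.println(requireStat(docFreqs, "body:foo", "term statistics")); // 42
        // requireStat(docFreqs, "body:bar", "term statistics") would throw
    }
}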

DfsPhase.java

@@ -20,17 +20,23 @@
 package org.elasticsearch.search.dfs;
 
 import com.google.common.collect.ImmutableMap;
+import gnu.trove.map.TMap;
 import gnu.trove.set.hash.THashSet;
 import org.apache.lucene.index.IndexReaderContext;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.TermContext;
+import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.TermStatistics;
+import org.elasticsearch.common.trove.ExtTHashMap;
 import org.elasticsearch.common.util.concurrent.ThreadLocals;
 import org.elasticsearch.search.SearchParseElement;
 import org.elasticsearch.search.SearchPhase;
 import org.elasticsearch.search.internal.SearchContext;
 
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  *

@@ -71,11 +77,16 @@ public class DfsPhase implements SearchPhase {
                 termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
             }
 
-            // TODO: LUCENE 4 UPGRADE - add collection stats for each unique field, for distributed scoring
-            // context.searcher().collectionStatistics()
+            TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
+            for (Term term : terms) {
+                if (!fieldStatistics.containsKey(term.field())) {
+                    fieldStatistics.put(term.field(), context.searcher().collectionStatistics(term.field()));
+                }
+            }
 
-            context.dfsResult().termsAndFreqs(terms, termStatistics);
-            context.dfsResult().maxDoc(context.searcher().getIndexReader().maxDoc());
+            context.dfsResult().termsStatistics(terms, termStatistics)
+                    .fieldStatistics(fieldStatistics)
+                    .maxDoc(context.searcher().getIndexReader().maxDoc());
         } catch (Exception e) {
             throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
        }
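
This resolves the "LUCENE 4 UPGRADE" TODO: the dfs phase now fetches CollectionStatistics once per distinct field, however many query terms share that field. The collection step factored into a standalone helper (class and method names are illustrative; IndexSearcher.collectionStatistics(String) is the real Lucene 4 call from the diff, and a JDK map stands in for the Trove one):

// Sketch only: compiles against lucene-core 4.x.
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.IndexSearcher;

public class CollectFieldStatsSketch {

    static Map<String, CollectionStatistics> collect(IndexSearcher searcher, Term[] terms) throws IOException {
        Map<String, CollectionStatistics> stats = new HashMap<String, CollectionStatistics>();
        for (Term term : terms) {
            // one statistics lookup per distinct field
            if (!stats.containsKey(term.field())) {
                stats.put(term.field(), searcher.collectionStatistics(term.field()));
            }
        }
        return stats;
    }
}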

DfsSearchResult.java

@@ -19,16 +19,20 @@
 package org.elasticsearch.search.dfs;
 
+import gnu.trove.map.TMap;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.search.CollectionStatistics;
 import org.apache.lucene.search.TermStatistics;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.trove.ExtTHashMap;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.transport.TransportResponse;
 
 import java.io.IOException;
+import java.util.Map;
 
 /**
  *

@@ -42,6 +46,7 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
     private long id;
     private Term[] terms;
    private TermStatistics[] termStatistics;
+    private TMap<String, CollectionStatistics> fieldStatistics = new ExtTHashMap<String, CollectionStatistics>();
     private int maxDoc;
 
     public DfsSearchResult() {

@@ -75,12 +80,17 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
         return maxDoc;
     }
 
-    public DfsSearchResult termsAndFreqs(Term[] terms, TermStatistics[] termStatistics) {
+    public DfsSearchResult termsStatistics(Term[] terms, TermStatistics[] termStatistics) {
         this.terms = terms;
         this.termStatistics = termStatistics;
         return this;
     }
 
+    public DfsSearchResult fieldStatistics(TMap<String, CollectionStatistics> fieldStatistics) {
+        this.fieldStatistics = fieldStatistics;
+        return this;
+    }
+
     public Term[] terms() {
         return terms;
     }

@@ -89,6 +99,10 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
         return termStatistics;
     }
 
+    public TMap<String, CollectionStatistics> fieldStatistics() {
+        return fieldStatistics;
+    }
+
     public static DfsSearchResult readDfsSearchResult(StreamInput in) throws IOException, ClassNotFoundException {
         DfsSearchResult result = new DfsSearchResult();
         result.readFrom(in);

@@ -121,6 +135,13 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
                 termStatistics[i] = new TermStatistics(term, docFreq, totalTermFreq);
             }
         }
+        int numFieldStatistics = in.readVInt();
+        for (int i = 0; i < numFieldStatistics; i++) {
+            String field = in.readString();
+            CollectionStatistics stats = new CollectionStatistics(field, in.readVLong(), in.readVLong(), in.readVLong(), in.readVLong());
+            fieldStatistics.put(field, stats);
+        }
         maxDoc = in.readVInt();
     }

@@ -139,6 +160,15 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
             out.writeVLong(termStatistic.docFreq());
             out.writeVLong(termStatistic.totalTermFreq());
         }
+
+        out.writeVInt(fieldStatistics.size());
+        for (Map.Entry<String, CollectionStatistics> entry : fieldStatistics.entrySet()) {
+            out.writeString(entry.getKey());
+            out.writeVLong(entry.getValue().maxDoc());
+            out.writeVLong(entry.getValue().docCount());
+            out.writeVLong(entry.getValue().sumTotalTermFreq());
+            out.writeVLong(entry.getValue().sumDocFreq());
+        }
         out.writeVInt(maxDoc);
     }
 }
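
A small API detail that makes the DfsPhase call site read cleanly: termsStatistics(...) and fieldStatistics(...) return this, which is what permits the chained context.dfsResult().termsStatistics(...).fieldStatistics(...).maxDoc(...) call shown earlier. The pattern in miniature (hypothetical names):

// Sketch only: the fluent-setter pattern used by DfsSearchResult.
public class FluentResultSketch {

    private String field;
    private int maxDoc;

    // Returning this from each setter lets callers chain configuration calls.
    FluentResultSketch field(String field) {
        this.field = field;
        return this;
    }

    FluentResultSketch maxDoc(int maxDoc) {
        this.maxDoc = maxDoc;
        return this;
    }

    public static void main(String[] args) {
        FluentResultSketch result = new FluentResultSketch().field("body").maxDoc(300);
        System.out.println(result.field + " " + result.maxDoc); // body 300
    }
}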