lucene 4: Upgraded o.e.search.dfs package. (Distributed idf)

Martijn van Groningen 2012-10-30 20:44:09 +01:00 committed by Shay Banon
parent fd2cf776d8
commit 5f942ef63d
8 changed files with 136 additions and 64 deletions

View File: StreamInput.java (org.elasticsearch.common.io.stream)

@@ -19,6 +19,7 @@
package org.elasticsearch.common.io.stream;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
@@ -84,6 +85,14 @@ public abstract class StreamInput extends InputStream {
return new BytesArray(bytes, 0, length);
}
public BytesRef readBytesRef() throws IOException {
int length = readVInt();
int offset = readVInt();
byte[] bytes = new byte[length];
readBytes(bytes, offset, length);
return new BytesRef(bytes, offset, length);
}
public void readFully(byte[] b) throws IOException {
readBytes(b, 0, b.length);
}
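readBytesRef() is the new read-side primitive the dfs serialization below relies on: a VInt length, a VInt offset, then the raw bytes. A minimal usage sketch, assuming in is any StreamInput positioned at a serialized term, mirroring how AggregatedDfs.readFrom later in this commit rebuilds a Term:

// Assumed: in is a StreamInput holding a serialized term (field string, then term bytes).
String field = in.readString();
BytesRef termBytes = in.readBytesRef();   // VInt length, VInt offset, then the bytes themselves
Term term = new Term(field, termBytes);   // the Lucene 4 Term constructor accepts a BytesRef directly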

View File: StreamOutput.java (org.elasticsearch.common.io.stream)

@@ -19,6 +19,7 @@
package org.elasticsearch.common.io.stream;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
@@ -106,6 +107,16 @@ public abstract class StreamOutput extends OutputStream {
bytes.writeTo(this);
}
public void writeBytesRef(BytesRef bytes) throws IOException {
if (bytes == null) {
writeVInt(0);
return;
}
writeVInt(bytes.length);
writeVInt(bytes.offset);
write(bytes.bytes, bytes.offset, bytes.length);
}
public final void writeShort(short v) throws IOException {
writeByte((byte) (v >> 8));
writeByte((byte) v);
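writeBytesRef() is the matching write-side primitive: a null value is written as a single zero VInt, otherwise length, offset, and the bytes follow. A minimal round-trip sketch with hypothetical helper names (writeTerm/readTerm are illustrations, not methods added by this commit), showing the field-then-bytes layout the dfs classes below use:

// Hypothetical helpers; AggregatedDfs and DfsSearchResult inline the same calls.
static void writeTerm(StreamOutput out, Term term) throws IOException {
    out.writeString(term.field());
    out.writeBytesRef(term.bytes());
}

static Term readTerm(StreamInput in) throws IOException {
    return new Term(in.readString(), in.readBytesRef());
}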

View File: SearchService.java

@@ -282,7 +282,7 @@ public class SearchService extends AbstractLifecycleComponent<SearchService> {
SearchContext context = findContext(request.id());
contextProcessing(context);
try {
context.searcher().dfSource(new CachedDfSource(request.dfs(), context.similarityService().defaultSearchSimilarity()));
context.searcher().dfSource(new CachedDfSource(context.searcher().getIndexReader(), request.dfs(), context.similarityService().defaultSearchSimilarity()));
} catch (IOException e) {
freeContext(context);
cleanContext(context);
@@ -348,7 +348,7 @@ public class SearchService extends AbstractLifecycleComponent<SearchService> {
SearchContext context = findContext(request.id());
contextProcessing(context);
try {
context.searcher().dfSource(new CachedDfSource(request.dfs(), context.similarityService().defaultSearchSimilarity()));
context.searcher().dfSource(new CachedDfSource(context.searcher().getIndexReader(), request.dfs(), context.similarityService().defaultSearchSimilarity()));
} catch (IOException e) {
freeContext(context);
cleanContext(context);
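The extra argument exists because CachedDfSource now extends Lucene 4's IndexSearcher (see the CachedDfSource change below), whose constructor requires a reader. A sketch of the wiring, restating the call above:

// The context searcher's own reader is handed to CachedDfSource so the
// IndexSearcher super-constructor has a real reader to wrap.
IndexReader reader = context.searcher().getIndexReader();
context.searcher().dfSource(new CachedDfSource(reader, request.dfs(),
        context.similarityService().defaultSearchSimilarity()));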

View File: SearchPhaseController.java

@@ -24,6 +24,7 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import gnu.trove.impl.Constants;
import gnu.trove.map.TMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.*;
@@ -32,6 +33,7 @@ import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.trove.ExtTIntArrayList;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.dfs.AggregatedDfs;
@@ -86,11 +88,17 @@ public class SearchPhaseController extends AbstractComponent {
}
public AggregatedDfs aggregateDfs(Iterable<DfsSearchResult> results) {
TObjectIntHashMap<Term> dfMap = new TObjectIntHashMap<Term>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR, -1);
TMap<Term, TermStatistics> dfMap = new ExtTHashMap<Term, TermStatistics>(Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
long aggMaxDoc = 0;
for (DfsSearchResult result : results) {
for (int i = 0; i < result.freqs().length; i++) {
dfMap.adjustOrPutValue(result.terms()[i], result.freqs()[i], result.freqs()[i]);
for (int i = 0; i < result.termStatistics().length; i++) {
TermStatistics existing = dfMap.get(result.terms()[i]);
if (existing != null) {
dfMap.put(result.terms()[i], new TermStatistics(existing.term(), existing.docFreq() + result.termStatistics()[i].docFreq(), existing.totalTermFreq() + result.termStatistics()[i].totalTermFreq()));
} else {
dfMap.put(result.terms()[i], result.termStatistics()[i]);
}
}
aggMaxDoc += result.maxDoc();
}
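The per-term merge above simply sums the shard-local counts. The same operation as a small helper (a sketch, not a method added by this commit):

// Document frequency and total term frequency are additive across disjoint shards,
// so merging two shards' statistics for one term is a pair of sums.
static TermStatistics merge(TermStatistics a, TermStatistics b) {
    return new TermStatistics(a.term(),
            a.docFreq() + b.docFreq(),
            a.totalTermFreq() + b.totalTermFreq());
}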
@@ -173,7 +181,7 @@
if (fDoc.fields[i] != null) {
allValuesAreNull = false;
if (fDoc.fields[i] instanceof String) {
fieldDocs.fields[i] = new SortField(fieldDocs.fields[i].getField(), SortField.STRING, fieldDocs.fields[i].getReverse());
fieldDocs.fields[i] = new SortField(fieldDocs.fields[i].getField(), SortField.Type.STRING, fieldDocs.fields[i].getReverse());
}
resolvedField = true;
break;
@@ -185,7 +193,7 @@
}
if (!resolvedField && allValuesAreNull && fieldDocs.fields[i].getField() != null) {
// we did not manage to resolve a field (and its not score or doc, which have no field), and all the fields are null (which can only happen for STRING), make it a STRING
fieldDocs.fields[i] = new SortField(fieldDocs.fields[i].getField(), SortField.STRING, fieldDocs.fields[i].getReverse());
fieldDocs.fields[i] = new SortField(fieldDocs.fields[i].getField(), SortField.Type.STRING, fieldDocs.fields[i].getReverse());
}
}
queue = new ShardFieldDocSortedHitQueue(fieldDocs.fields, queueSize);
@@ -270,7 +278,7 @@
sorted = true;
TopFieldDocs fieldDocs = (TopFieldDocs) querySearchResult.queryResult().topDocs();
for (int i = 0; i < fieldDocs.fields.length; i++) {
if (fieldDocs.fields[i].getType() == SortField.SCORE) {
if (fieldDocs.fields[i].getType() == SortField.Type.SCORE) {
sortScoreIndex = i;
}
}
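The SortField edits in this file are mechanical: Lucene 4 moved the integer sort-type constants into the SortField.Type enum, so SortField.STRING becomes SortField.Type.STRING and SortField.SCORE becomes SortField.Type.SCORE. For example:

// Lucene 3.x: new SortField("name", SortField.STRING, true)
// Lucene 4.x:
SortField byName = new SortField("name", SortField.Type.STRING, true);   // true = reverse order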

View File: AggregatedDfs.java (org.elasticsearch.search.dfs)

@@ -21,21 +21,25 @@ package org.elasticsearch.search.dfs;
import gnu.trove.impl.Constants;
import gnu.trove.iterator.TObjectIntIterator;
import gnu.trove.map.TMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
import org.elasticsearch.common.trove.ExtTHashMap;
import org.elasticsearch.common.trove.ExtTObjectIntHasMap;
import java.io.IOException;
import java.util.Map;
/**
*
*/
public class AggregatedDfs implements Streamable {
private TObjectIntHashMap<Term> dfMap;
private TMap<Term, TermStatistics> dfMap;
private long maxDoc;
@@ -43,12 +47,12 @@ public class AggregatedDfs implements Streamable {
}
public AggregatedDfs(TObjectIntHashMap<Term> dfMap, long maxDoc) {
public AggregatedDfs(TMap<Term, TermStatistics> dfMap, long maxDoc) {
this.dfMap = dfMap;
this.maxDoc = maxDoc;
}
public TObjectIntHashMap<Term> dfMap() {
public TMap<Term, TermStatistics> dfMap() {
return dfMap;
}
@@ -65,9 +69,11 @@
@Override
public void readFrom(StreamInput in) throws IOException {
int size = in.readVInt();
dfMap = new ExtTObjectIntHasMap<Term>(size, Constants.DEFAULT_LOAD_FACTOR, -1);
dfMap = new ExtTHashMap<Term, TermStatistics>(size, Constants.DEFAULT_LOAD_FACTOR);
for (int i = 0; i < size; i++) {
dfMap.put(new Term(in.readUTF(), in.readUTF()), in.readVInt());
Term term = new Term(in.readString(), in.readBytesRef());
TermStatistics stats = new TermStatistics(in.readBytesRef(), in.readVLong(), in.readVLong());
dfMap.put(term, stats);
}
maxDoc = in.readVLong();
}
@@ -76,12 +82,16 @@
public void writeTo(final StreamOutput out) throws IOException {
out.writeVInt(dfMap.size());
for (TObjectIntIterator<Term> it = dfMap.iterator(); it.hasNext(); ) {
it.advance();
out.writeUTF(it.key().field());
out.writeUTF(it.key().text());
out.writeVInt(it.value());
for (Map.Entry<Term, TermStatistics> termTermStatisticsEntry : dfMap.entrySet()) {
Term term = termTermStatisticsEntry.getKey();
out.writeString(term.field());
out.writeBytesRef(term.bytes());
TermStatistics stats = termTermStatisticsEntry.getValue();
out.writeBytesRef(stats.term());
out.writeVLong(stats.docFreq());
out.writeVLong(stats.totalTermFreq());
}
out.writeVLong(maxDoc);
}
}
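A minimal construction sketch of the new map-based form (the values are made up for illustration; ExtTHashMap is the Elasticsearch trove wrapper already imported above):

TMap<Term, TermStatistics> dfMap = new ExtTHashMap<Term, TermStatistics>(
        Constants.DEFAULT_CAPACITY, Constants.DEFAULT_LOAD_FACTOR);
Term term = new Term("body", "lucene");
dfMap.put(term, new TermStatistics(term.bytes(), 42L, 100L));   // docFreq = 42, totalTermFreq = 100
AggregatedDfs dfs = new AggregatedDfs(dfMap, 10000L);           // maxDoc summed across shards
// dfs.writeTo(out) then emits, per entry: field string, term bytes, the statistics' term bytes,
// docFreq as a VLong, totalTermFreq as a VLong; and finally maxDoc as a single VLong.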

View File: CachedDfSource.java (org.elasticsearch.search.dfs)

@@ -20,22 +20,26 @@
package org.elasticsearch.search.dfs;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldSelector;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import org.apache.lucene.search.similarities.Similarity;
import org.elasticsearch.ElasticSearchIllegalArgumentException;
import org.elasticsearch.search.internal.SearchContext;
import java.io.IOException;
import java.util.List;
/**
*
*/
public class CachedDfSource extends Searcher {
public class CachedDfSource extends IndexSearcher {
private final AggregatedDfs dfs;
private final int maxDoc;
public CachedDfSource(AggregatedDfs dfs, Similarity similarity) throws IOException {
public CachedDfSource(IndexReader reader, AggregatedDfs dfs, Similarity similarity) throws IOException {
super(reader);
this.dfs = dfs;
setSimilarity(similarity);
if (dfs.maxDoc() > Integer.MAX_VALUE) {
@@ -45,21 +49,19 @@ public class CachedDfSource extends Searcher {
}
}
public int docFreq(Term term) {
int df = dfs.dfMap().get(term);
if (df == -1) {
return 1;
// throw new IllegalArgumentException("df for term " + term + " not available");
@Override
public TermStatistics termStatistics(Term term, TermContext context) throws IOException {
TermStatistics termStatistics = dfs.dfMap().get(term);
if (termStatistics == null) {
throw new ElasticSearchIllegalArgumentException("Not distributed term statistics for term: " + term);
}
return df;
return termStatistics;
}
public int[] docFreqs(Term[] terms) {
int[] result = new int[terms.length];
for (int i = 0; i < terms.length; i++) {
result[i] = docFreq(terms[i]);
}
return result;
@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
throw new UnsupportedOperationException();
}
public int maxDoc() {
@@ -74,15 +76,11 @@ public class CachedDfSource extends Searcher {
return query;
}
public void close() {
throw new UnsupportedOperationException();
}
public Document doc(int i) {
throw new UnsupportedOperationException();
}
public Document doc(int i, FieldSelector fieldSelector) {
public void doc(int docID, StoredFieldVisitor fieldVisitor) throws IOException {
throw new UnsupportedOperationException();
}
@@ -90,15 +88,33 @@
throw new UnsupportedOperationException();
}
public void search(Weight weight, Filter filter, Collector results) {
@Override
protected void search(List<AtomicReaderContext> leaves, Weight weight, Collector collector) throws IOException {
throw new UnsupportedOperationException();
}
public TopDocs search(Weight weight, Filter filter, int n) {
@Override
protected TopDocs search(Weight weight, ScoreDoc after, int nDocs) throws IOException {
throw new UnsupportedOperationException();
}
public TopFieldDocs search(Weight weight, Filter filter, int n, Sort sort) {
@Override
protected TopDocs search(List<AtomicReaderContext> leaves, Weight weight, ScoreDoc after, int nDocs) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected TopFieldDocs search(Weight weight, int nDocs, Sort sort, boolean doDocScores, boolean doMaxScore) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected TopFieldDocs search(Weight weight, FieldDoc after, int nDocs, Sort sort, boolean fillFields, boolean doDocScores, boolean doMaxScore) throws IOException {
throw new UnsupportedOperationException();
}
@Override
protected TopFieldDocs search(List<AtomicReaderContext> leaves, Weight weight, FieldDoc after, int nDocs, Sort sort, boolean fillFields, boolean doDocScores, boolean doMaxScore) throws IOException {
throw new UnsupportedOperationException();
}
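Because CachedDfSource is now a real IndexSearcher, whatever asks this searcher for term statistics while building query weights receives the cluster-wide numbers aggregated by SearchPhaseController instead of the local shard's counts. A minimal caller-side sketch (dfSource and termContext are assumed names for a populated CachedDfSource and a TermContext for the term):

TermStatistics distributed = dfSource.termStatistics(term, termContext);
long df  = distributed.docFreq();         // summed over all shards by aggregateDfs
long ttf = distributed.totalTermFreq();   // likewise summed
// collectionStatistics(field) still throws UnsupportedOperationException here,
// matching the "add collection stats" TODO left in DfsPhase below.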

View File: DfsPhase.java (org.elasticsearch.search.dfs)

@@ -21,7 +21,10 @@ package org.elasticsearch.search.dfs;
import com.google.common.collect.ImmutableMap;
import gnu.trove.set.hash.THashSet;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermContext;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.util.concurrent.ThreadLocals;
import org.elasticsearch.search.SearchParseElement;
import org.elasticsearch.search.SearchPhase;
@@ -60,12 +63,21 @@ public class DfsPhase implements SearchPhase {
termsSet.clear();
context.query().extractTerms(termsSet);
Term[] terms = termsSet.toArray(new Term[termsSet.size()]);
int[] freqs = context.searcher().docFreqs(terms);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
for (int i = 0; i < terms.length; i++) {
// LUCENE 4 UPGRADE: cache TermContext?
TermContext termContext = TermContext.build(indexReaderContext, terms[i], false);
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
}
context.dfsResult().termsAndFreqs(terms, freqs);
// TODO: LUCENE 4 UPGRADE - add collection stats for each unique field, for distributed scoring
// context.searcher().collectionStatistics()
context.dfsResult().termsAndFreqs(terms, termStatistics);
context.dfsResult().maxDoc(context.searcher().getIndexReader().maxDoc());
} catch (Exception e) {
throw new DfsPhaseExecutionException(context, "", e);
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
}
}
}
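TermContext.build does the expensive part of the loop above: it walks every segment once, records where the term occurs, and accumulates reader-wide counts, so the following termStatistics call can answer without seeking the term dictionary again. A standalone sketch (reader is an assumed IndexReader; the build call matches the one used above):

IndexReaderContext top = reader.getContext();
Term term = new Term("body", "lucene");
TermContext states = TermContext.build(top, term, false);    // same call as in DfsPhase above
// The searcher's default termStatistics() essentially returns these reader-wide counts:
TermStatistics local = new TermStatistics(term.bytes(), states.docFreq(), states.totalTermFreq());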

View File: DfsSearchResult.java (org.elasticsearch.search.dfs)

@@ -20,6 +20,8 @@
package org.elasticsearch.search.dfs;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.SearchPhaseResult;
@@ -33,13 +35,13 @@ import java.io.IOException;
*/
public class DfsSearchResult extends TransportResponse implements SearchPhaseResult {
private static Term[] EMPTY_TERMS = new Term[0];
private static int[] EMPTY_FREQS = new int[0];
private static final Term[] EMPTY_TERMS = new Term[0];
private static final TermStatistics[] EMPTY_TERM_STATS = new TermStatistics[0];
private SearchShardTarget shardTarget;
private long id;
private Term[] terms;
private int[] freqs;
private TermStatistics[] termStatistics;
private int maxDoc;
public DfsSearchResult() {
@@ -73,9 +75,9 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
return maxDoc;
}
public DfsSearchResult termsAndFreqs(Term[] terms, int[] freqs) {
public DfsSearchResult termsAndFreqs(Term[] terms, TermStatistics[] termStatistics) {
this.terms = terms;
this.freqs = freqs;
this.termStatistics = termStatistics;
return this;
}
@@ -83,8 +85,8 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
return terms;
}
public int[] freqs() {
return freqs;
public TermStatistics[] termStatistics() {
return termStatistics;
}
public static DfsSearchResult readDfsSearchResult(StreamInput in) throws IOException, ClassNotFoundException {
@@ -104,16 +106,19 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
} else {
terms = new Term[termsSize];
for (int i = 0; i < terms.length; i++) {
terms[i] = new Term(in.readUTF(), in.readUTF());
terms[i] = new Term(in.readString(), in.readBytesRef());
}
}
int freqsSize = in.readVInt();
if (freqsSize == 0) {
freqs = EMPTY_FREQS;
int termsStatsSize = in.readVInt();
if (termsStatsSize == 0) {
termStatistics = EMPTY_TERM_STATS;
} else {
freqs = new int[freqsSize];
for (int i = 0; i < freqs.length; i++) {
freqs[i] = in.readVInt();
termStatistics = new TermStatistics[termsStatsSize];
for (int i = 0; i < termStatistics.length; i++) {
BytesRef term = terms[i].bytes();
long docFreq = in.readVLong();
long totalTermFreq = in.readVLong();
termStatistics[i] = new TermStatistics(term, docFreq, totalTermFreq);
}
}
maxDoc = in.readVInt();
@@ -126,12 +131,13 @@ public class DfsSearchResult extends TransportResponse implements SearchPhaseRes
// shardTarget.writeTo(out);
out.writeVInt(terms.length);
for (Term term : terms) {
out.writeUTF(term.field());
out.writeUTF(term.text());
out.writeString(term.field());
out.writeBytesRef(term.bytes());
}
out.writeVInt(freqs.length);
for (int freq : freqs) {
out.writeVInt(freq);
out.writeVInt(termStatistics.length);
for (TermStatistics termStatistic : termStatistics) {
out.writeVLong(termStatistic.docFreq());
out.writeVLong(termStatistic.totalTermFreq());
}
out.writeVInt(maxDoc);
}
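Note that writeTo above never re-serializes each statistic's own term bytes: the read side rebuilds every TermStatistics from terms[i].bytes(), so the two arrays must stay index-aligned. A sketch of the resulting stream layout (comments only, summarizing the code above):

// VInt  terms.length
//   per term: String field, BytesRef term bytes
// VInt  termStatistics.length            (equal to terms.length in practice)
//   per entry: VLong docFreq, VLong totalTermFreq   (term bytes re-derived from terms[i])
// VInt  maxDoc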