mirror of https://github.com/apache/lucene.git
LUCENE-5545: add SortRescorer and Expression.getRescorer
git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1580490 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
parent
68825572d2
commit
cbd8cacc50
|
@ -127,6 +127,10 @@ New Features
|
|||
first pass search using scores from a more costly second pass
|
||||
search. (Simon Willnauer, Robert Muir, Mike McCandless)
|
||||
|
||||
* LUCENE-5545: Add SortRescorer and Expression.getRescorer, to
|
||||
resort the hits from a first pass search using a Sort or an
|
||||
Expression. (Simon Willnauer, Robert Muir, Mike McCandless)
|
||||
|
||||
API Changes
|
||||
|
||||
* LUCENE-5454: Add RandomAccessOrds, an optional extension of SortedSetDocValues
|
||||
|
|
|
@ -20,13 +20,9 @@ package org.apache.lucene.search;
|
|||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.index.AtomicReaderContext;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
// TODO: we could also have an ExpressionRescorer
|
||||
|
||||
/** A {@link Rescorer} that uses a provided Query to assign
|
||||
* scores to the first-pass hits.
|
||||
|
@ -52,43 +48,65 @@ public abstract class QueryRescorer extends Rescorer {
|
|||
protected abstract float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore);
|
||||
|
||||
@Override
|
||||
public TopDocs rescore(IndexSearcher searcher, TopDocs topDocs, int topN) throws IOException {
|
||||
int[] docIDs = new int[topDocs.scoreDocs.length];
|
||||
for(int i=0;i<docIDs.length;i++) {
|
||||
docIDs[i] = topDocs.scoreDocs[i].doc;
|
||||
}
|
||||
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException {
|
||||
ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();
|
||||
Arrays.sort(hits,
|
||||
new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
return a.doc - b.doc;
|
||||
}
|
||||
});
|
||||
|
||||
TopDocs topDocs2 = searcher.search(query, new OnlyDocIDsFilter(docIDs), topDocs.scoreDocs.length);
|
||||
List<AtomicReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||
|
||||
// TODO: we could save small young GC cost here if we
|
||||
// cloned the incoming ScoreDoc[], sorted that by doc,
|
||||
// passed that to OnlyDocIDsFilter, sorted 2nd pass
|
||||
// TopDocs by doc, did a merge sort to combine the
|
||||
// scores, and finally re-sorted by the combined score,
|
||||
// but that is sizable added code complexity for minor
|
||||
// GC savings:
|
||||
Map<Integer,Float> newScores = new HashMap<Integer,Float>();
|
||||
for(ScoreDoc sd : topDocs2.scoreDocs) {
|
||||
newScores.put(sd.doc, sd.score);
|
||||
}
|
||||
Weight weight = searcher.createNormalizedWeight(query);
|
||||
|
||||
ScoreDoc[] newHits = new ScoreDoc[topDocs.scoreDocs.length];
|
||||
for(int i=0;i<topDocs.scoreDocs.length;i++) {
|
||||
ScoreDoc sd = topDocs.scoreDocs[i];
|
||||
Float newScore = newScores.get(sd.doc);
|
||||
float combinedScore;
|
||||
if (newScore == null) {
|
||||
combinedScore = combine(sd.score, false, 0.0f);
|
||||
} else {
|
||||
combinedScore = combine(sd.score, true, newScore.floatValue());
|
||||
// Now merge sort docIDs from hits, with reader's leaves:
|
||||
int hitUpto = 0;
|
||||
int readerUpto = -1;
|
||||
int endDoc = 0;
|
||||
int docBase = 0;
|
||||
Scorer scorer = null;
|
||||
|
||||
while (hitUpto < hits.length) {
|
||||
ScoreDoc hit = hits[hitUpto];
|
||||
int docID = hit.doc;
|
||||
AtomicReaderContext readerContext = null;
|
||||
while (docID >= endDoc) {
|
||||
readerUpto++;
|
||||
readerContext = leaves.get(readerUpto);
|
||||
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
|
||||
}
|
||||
newHits[i] = new ScoreDoc(sd.doc, combinedScore);
|
||||
|
||||
if (readerContext != null) {
|
||||
// We advanced to another segment:
|
||||
docBase = readerContext.docBase;
|
||||
scorer = weight.scorer(readerContext, null);
|
||||
}
|
||||
|
||||
int targetDoc = docID - docBase;
|
||||
int actualDoc = scorer.docID();
|
||||
if (actualDoc < targetDoc) {
|
||||
actualDoc = scorer.advance(targetDoc);
|
||||
}
|
||||
|
||||
if (actualDoc == targetDoc) {
|
||||
// Query did match this doc:
|
||||
hit.score = combine(hit.score, true, scorer.score());
|
||||
} else {
|
||||
// Query did not match this doc:
|
||||
assert actualDoc > targetDoc;
|
||||
hit.score = combine(hit.score, false, 0.0f);
|
||||
}
|
||||
|
||||
hitUpto++;
|
||||
}
|
||||
|
||||
// TODO: we should do a partial sort (of only topN)
|
||||
// instead, but typically the number of hits is
|
||||
// smallish:
|
||||
Arrays.sort(newHits,
|
||||
Arrays.sort(hits,
|
||||
new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
|
@ -105,13 +123,13 @@ public abstract class QueryRescorer extends Rescorer {
|
|||
}
|
||||
});
|
||||
|
||||
if (topN < newHits.length) {
|
||||
if (topN < hits.length) {
|
||||
ScoreDoc[] subset = new ScoreDoc[topN];
|
||||
System.arraycopy(newHits, 0, subset, 0, topN);
|
||||
newHits = subset;
|
||||
System.arraycopy(hits, 0, subset, 0, topN);
|
||||
hits = subset;
|
||||
}
|
||||
|
||||
return new TopDocs(topDocs.totalHits, newHits, newHits[0].score);
|
||||
return new TopDocs(firstPassTopDocs.totalHits, hits, hits[0].score);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -159,80 +177,4 @@ public abstract class QueryRescorer extends Rescorer {
|
|||
}
|
||||
}.rescore(searcher, topDocs, topN);
|
||||
}
|
||||
|
||||
/** Filter accepting only the specified docIDs */
|
||||
private static class OnlyDocIDsFilter extends Filter {
|
||||
|
||||
private final int[] docIDs;
|
||||
|
||||
/** Sole constructor. */
|
||||
public OnlyDocIDsFilter(int[] docIDs) {
|
||||
this.docIDs = docIDs;
|
||||
Arrays.sort(docIDs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSet getDocIdSet(final AtomicReaderContext context, final Bits acceptDocs) throws IOException {
|
||||
int loc = Arrays.binarySearch(docIDs, context.docBase);
|
||||
if (loc < 0) {
|
||||
loc = -loc-1;
|
||||
}
|
||||
|
||||
final int startLoc = loc;
|
||||
final int endDoc = context.docBase + context.reader().maxDoc();
|
||||
|
||||
return new DocIdSet() {
|
||||
|
||||
int pos = startLoc;
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() throws IOException {
|
||||
return new DocIdSetIterator() {
|
||||
|
||||
int docID;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
if (pos == docIDs.length) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
int docID = docIDs[pos];
|
||||
if (docID >= endDoc) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
pos++;
|
||||
assert acceptDocs == null || acceptDocs.get(docID-context.docBase);
|
||||
return docID-context.docBase;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
// NOTE: not quite right, since this is cost
|
||||
// across all segments, and we are supposed to
|
||||
// return cost for just this segment:
|
||||
return docIDs.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
// TODO: this is a full binary search; we
|
||||
// could optimize (a bit) by setting lower
|
||||
// bound to current pos instead:
|
||||
int loc = Arrays.binarySearch(docIDs, target + context.docBase);
|
||||
if (loc < 0) {
|
||||
loc = -loc-1;
|
||||
}
|
||||
pos = loc;
|
||||
return nextDoc();
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.lucene.index.AtomicReaderContext;
|
||||
|
||||
/**
|
||||
* A {@link Rescorer} that re-sorts according to a provided
|
||||
* Sort.
|
||||
*/
|
||||
|
||||
public class SortRescorer extends Rescorer {
|
||||
|
||||
private final Sort sort;
|
||||
|
||||
/** Sole constructor. */
|
||||
public SortRescorer(Sort sort) {
|
||||
this.sort = sort;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException {
|
||||
|
||||
// Copy ScoreDoc[] and sort by ascending docID:
|
||||
ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();
|
||||
Arrays.sort(hits,
|
||||
new Comparator<ScoreDoc>() {
|
||||
@Override
|
||||
public int compare(ScoreDoc a, ScoreDoc b) {
|
||||
return a.doc - b.doc;
|
||||
}
|
||||
});
|
||||
|
||||
List<AtomicReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||
|
||||
TopFieldCollector collector = TopFieldCollector.create(sort, topN, true, true, true, false);
|
||||
|
||||
// Now merge sort docIDs from hits, with reader's leaves:
|
||||
int hitUpto = 0;
|
||||
int readerUpto = -1;
|
||||
int endDoc = 0;
|
||||
int docBase = 0;
|
||||
|
||||
FakeScorer fakeScorer = new FakeScorer();
|
||||
|
||||
while (hitUpto < hits.length) {
|
||||
ScoreDoc hit = hits[hitUpto];
|
||||
int docID = hit.doc;
|
||||
AtomicReaderContext readerContext = null;
|
||||
while (docID >= endDoc) {
|
||||
readerUpto++;
|
||||
readerContext = leaves.get(readerUpto);
|
||||
endDoc = readerContext.docBase + readerContext.reader().maxDoc();
|
||||
}
|
||||
|
||||
if (readerContext != null) {
|
||||
// We advanced to another segment:
|
||||
collector.setNextReader(readerContext);
|
||||
collector.setScorer(fakeScorer);
|
||||
docBase = readerContext.docBase;
|
||||
}
|
||||
|
||||
fakeScorer.score = hit.score;
|
||||
fakeScorer.doc = docID - docBase;
|
||||
|
||||
collector.collect(fakeScorer.doc);
|
||||
|
||||
hitUpto++;
|
||||
}
|
||||
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException {
|
||||
TopDocs oneHit = new TopDocs(1, new ScoreDoc[] {new ScoreDoc(docID, firstPassExplanation.getValue())});
|
||||
TopDocs hits = rescore(searcher, oneHit, 1);
|
||||
assert hits.totalHits == 1;
|
||||
|
||||
// TODO: if we could ask the Sort to explain itself then
|
||||
// we wouldn't need the separate ExpressionRescorer...
|
||||
Explanation result = new Explanation(0.0f, "sort field values for sort=" + sort.toString());
|
||||
|
||||
// Add first pass:
|
||||
Explanation first = new Explanation(firstPassExplanation.getValue(), "first pass score");
|
||||
first.addDetail(firstPassExplanation);
|
||||
result.addDetail(first);
|
||||
|
||||
FieldDoc fieldDoc = (FieldDoc) hits.scoreDocs[0];
|
||||
|
||||
// Add sort values:
|
||||
SortField[] sortFields = sort.getSort();
|
||||
for(int i=0;i<sortFields.length;i++) {
|
||||
result.addDetail(new Explanation(0.0f, "sort field " + sortFields[i].toString() + " value=" + fieldDoc.fields[i]));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -17,8 +17,15 @@ package org.apache.lucene.search;
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Set;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.index.AtomicReaderContext;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
|
@ -28,7 +35,9 @@ import org.apache.lucene.search.spans.SpanNearQuery;
|
|||
import org.apache.lucene.search.spans.SpanQuery;
|
||||
import org.apache.lucene.search.spans.SpanTermQuery;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.TestUtil;
|
||||
|
||||
public class TestQueryRescorer extends LuceneTestCase {
|
||||
|
||||
|
@ -62,6 +71,7 @@ public class TestQueryRescorer extends LuceneTestCase {
|
|||
bq.add(new TermQuery(new Term("field", "wizard")), Occur.SHOULD);
|
||||
bq.add(new TermQuery(new Term("field", "oz")), Occur.SHOULD);
|
||||
IndexSearcher searcher = getSearcher(r);
|
||||
searcher.setSimilarity(new DefaultSimilarity());
|
||||
|
||||
TopDocs hits = searcher.search(bq, 10);
|
||||
assertEquals(2, hits.totalHits);
|
||||
|
@ -283,4 +293,206 @@ public class TestQueryRescorer extends LuceneTestCase {
|
|||
r.close();
|
||||
dir.close();
|
||||
}
|
||||
|
||||
public void testRandom() throws Exception {
|
||||
Directory dir = newDirectory();
|
||||
int numDocs = atLeast(1000);
|
||||
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
|
||||
|
||||
final int[] idToNum = new int[numDocs];
|
||||
int maxValue = TestUtil.nextInt(random(), 10, 1000000);
|
||||
for(int i=0;i<numDocs;i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(newStringField("id", ""+i, Field.Store.YES));
|
||||
int numTokens = TestUtil.nextInt(random(), 1, 10);
|
||||
StringBuilder b = new StringBuilder();
|
||||
for(int j=0;j<numTokens;j++) {
|
||||
b.append("a ");
|
||||
}
|
||||
doc.add(newTextField("field", b.toString(), Field.Store.NO));
|
||||
idToNum[i] = random().nextInt(maxValue);
|
||||
doc.add(new NumericDocValuesField("num", idToNum[i]));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
final IndexReader r = w.getReader();
|
||||
w.close();
|
||||
|
||||
IndexSearcher s = newSearcher(r);
|
||||
int numHits = TestUtil.nextInt(random(), 1, numDocs);
|
||||
boolean reverse = random().nextBoolean();
|
||||
|
||||
//System.out.println("numHits=" + numHits + " reverse=" + reverse);
|
||||
TopDocs hits = s.search(new TermQuery(new Term("field", "a")), numHits);
|
||||
|
||||
TopDocs hits2 = new QueryRescorer(new FixedScoreQuery(idToNum, reverse)) {
|
||||
@Override
|
||||
protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
|
||||
return secondPassScore;
|
||||
}
|
||||
}.rescore(s, hits, numHits);
|
||||
|
||||
Integer[] expected = new Integer[numHits];
|
||||
for(int i=0;i<numHits;i++) {
|
||||
expected[i] = hits.scoreDocs[i].doc;
|
||||
}
|
||||
|
||||
final int reverseInt = reverse ? -1 : 1;
|
||||
|
||||
Arrays.sort(expected,
|
||||
new Comparator<Integer>() {
|
||||
@Override
|
||||
public int compare(Integer a, Integer b) {
|
||||
try {
|
||||
int av = idToNum[Integer.parseInt(r.document(a).get("id"))];
|
||||
int bv = idToNum[Integer.parseInt(r.document(b).get("id"))];
|
||||
if (av < bv) {
|
||||
return -reverseInt;
|
||||
} else if (bv < av) {
|
||||
return reverseInt;
|
||||
} else {
|
||||
// Tie break by docID, ascending
|
||||
return a - b;
|
||||
}
|
||||
} catch (IOException ioe) {
|
||||
throw new RuntimeException(ioe);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
boolean fail = false;
|
||||
for(int i=0;i<numHits;i++) {
|
||||
//System.out.println("expected=" + expected[i] + " vs " + hits2.scoreDocs[i].doc + " v=" + idToNum[Integer.parseInt(r.document(expected[i]).get("id"))]);
|
||||
if (expected[i].intValue() != hits2.scoreDocs[i].doc) {
|
||||
//System.out.println(" diff!");
|
||||
fail = true;
|
||||
}
|
||||
}
|
||||
assertFalse(fail);
|
||||
|
||||
r.close();
|
||||
dir.close();
|
||||
}
|
||||
|
||||
/** Just assigns score == idToNum[doc("id")] for each doc. */
|
||||
private static class FixedScoreQuery extends Query {
|
||||
private final int[] idToNum;
|
||||
private final boolean reverse;
|
||||
|
||||
public FixedScoreQuery(int[] idToNum, boolean reverse) {
|
||||
this.idToNum = idToNum;
|
||||
this.reverse = reverse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher) throws IOException {
|
||||
|
||||
return new Weight() {
|
||||
|
||||
@Override
|
||||
public Query getQuery() {
|
||||
return FixedScoreQuery.this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getValueForNormalization() {
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void normalize(float queryNorm, float topLevelBoost) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Scorer scorer(final AtomicReaderContext context, Bits acceptDocs) throws IOException {
|
||||
|
||||
return new Scorer(null) {
|
||||
int docID = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int freq() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
docID++;
|
||||
if (docID >= context.reader().maxDoc()) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
docID = target;
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
int num = idToNum[Integer.parseInt(context.reader().document(docID).get("id"))];
|
||||
if (reverse) {
|
||||
//System.out.println("score doc=" + docID + " num=" + num);
|
||||
return num;
|
||||
} else {
|
||||
//System.out.println("score doc=" + docID + " num=" + -num);
|
||||
return -num;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(AtomicReaderContext context, int doc) throws IOException {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void extractTerms(Set<Term> terms) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return "FixedScoreQuery " + idToNum.length + " ids; reverse=" + reverse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if ((o instanceof FixedScoreQuery) == false) {
|
||||
return false;
|
||||
}
|
||||
FixedScoreQuery other = (FixedScoreQuery) o;
|
||||
return Float.floatToIntBits(getBoost()) == Float.floatToIntBits(other.getBoost()) &&
|
||||
reverse == other.reverse &&
|
||||
Arrays.equals(idToNum, other.idToNum);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query clone() {
|
||||
return new FixedScoreQuery(idToNum, reverse);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int PRIME = 31;
|
||||
int hash = super.hashCode();
|
||||
if (reverse) {
|
||||
hash = PRIME * hash + 3623;
|
||||
}
|
||||
hash = PRIME * hash + Arrays.hashCode(idToNum);
|
||||
return hash;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.TestUtil;
|
||||
|
||||
public class TestSortRescorer extends LuceneTestCase {
|
||||
IndexSearcher searcher;
|
||||
DirectoryReader reader;
|
||||
Directory dir;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir);
|
||||
|
||||
Document doc = new Document();
|
||||
doc.add(newStringField("id", "1", Field.Store.YES));
|
||||
doc.add(newTextField("body", "some contents and more contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 5));
|
||||
iw.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(newStringField("id", "2", Field.Store.YES));
|
||||
doc.add(newTextField("body", "another document with different contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 20));
|
||||
iw.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(newStringField("id", "3", Field.Store.YES));
|
||||
doc.add(newTextField("body", "crappy contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 2));
|
||||
iw.addDocument(doc);
|
||||
|
||||
reader = iw.getReader();
|
||||
searcher = new IndexSearcher(reader);
|
||||
iw.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void tearDown() throws Exception {
|
||||
reader.close();
|
||||
dir.close();
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
public void testBasic() throws Exception {
|
||||
|
||||
// create a sort field and sort by it (reverse order)
|
||||
Query query = new TermQuery(new Term("body", "contents"));
|
||||
IndexReader r = searcher.getIndexReader();
|
||||
|
||||
// Just first pass query
|
||||
TopDocs hits = searcher.search(query, 10);
|
||||
assertEquals(3, hits.totalHits);
|
||||
assertEquals("3", r.document(hits.scoreDocs[0].doc).get("id"));
|
||||
assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id"));
|
||||
assertEquals("2", r.document(hits.scoreDocs[2].doc).get("id"));
|
||||
|
||||
// Now, rescore:
|
||||
Sort sort = new Sort(new SortField("popularity", SortField.Type.INT, true));
|
||||
Rescorer rescorer = new SortRescorer(sort);
|
||||
hits = rescorer.rescore(searcher, hits, 10);
|
||||
assertEquals(3, hits.totalHits);
|
||||
assertEquals("2", r.document(hits.scoreDocs[0].doc).get("id"));
|
||||
assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id"));
|
||||
assertEquals("3", r.document(hits.scoreDocs[2].doc).get("id"));
|
||||
|
||||
String expl = rescorer.explain(searcher,
|
||||
searcher.explain(query, hits.scoreDocs[0].doc),
|
||||
hits.scoreDocs[0].doc).toString();
|
||||
|
||||
// Confirm the explanation breaks out the individual
|
||||
// sort fields:
|
||||
assertTrue(expl.contains("= sort field <int: \"popularity\">! value=20"));
|
||||
|
||||
// Confirm the explanation includes first pass details:
|
||||
assertTrue(expl.contains("= first pass score"));
|
||||
assertTrue(expl.contains("body:contents in"));
|
||||
}
|
||||
|
||||
public void testRandom() throws Exception {
|
||||
Directory dir = newDirectory();
|
||||
int numDocs = atLeast(1000);
|
||||
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
|
||||
|
||||
final int[] idToNum = new int[numDocs];
|
||||
int maxValue = TestUtil.nextInt(random(), 10, 1000000);
|
||||
for(int i=0;i<numDocs;i++) {
|
||||
Document doc = new Document();
|
||||
doc.add(newStringField("id", ""+i, Field.Store.YES));
|
||||
int numTokens = TestUtil.nextInt(random(), 1, 10);
|
||||
StringBuilder b = new StringBuilder();
|
||||
for(int j=0;j<numTokens;j++) {
|
||||
b.append("a ");
|
||||
}
|
||||
doc.add(newTextField("field", b.toString(), Field.Store.NO));
|
||||
idToNum[i] = random().nextInt(maxValue);
|
||||
doc.add(new NumericDocValuesField("num", idToNum[i]));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
final IndexReader r = w.getReader();
|
||||
w.close();
|
||||
|
||||
IndexSearcher s = newSearcher(r);
|
||||
int numHits = TestUtil.nextInt(random(), 1, numDocs);
|
||||
boolean reverse = random().nextBoolean();
|
||||
|
||||
TopDocs hits = s.search(new TermQuery(new Term("field", "a")), numHits);
|
||||
|
||||
Rescorer rescorer = new SortRescorer(new Sort(new SortField("num", SortField.Type.INT, reverse)));
|
||||
TopDocs hits2 = rescorer.rescore(s, hits, numHits);
|
||||
|
||||
Integer[] expected = new Integer[numHits];
|
||||
for(int i=0;i<numHits;i++) {
|
||||
expected[i] = hits.scoreDocs[i].doc;
|
||||
}
|
||||
|
||||
final int reverseInt = reverse ? -1 : 1;
|
||||
|
||||
Arrays.sort(expected,
|
||||
new Comparator<Integer>() {
|
||||
@Override
|
||||
public int compare(Integer a, Integer b) {
|
||||
try {
|
||||
int av = idToNum[Integer.parseInt(r.document(a).get("id"))];
|
||||
int bv = idToNum[Integer.parseInt(r.document(b).get("id"))];
|
||||
if (av < bv) {
|
||||
return -reverseInt;
|
||||
} else if (bv < av) {
|
||||
return reverseInt;
|
||||
} else {
|
||||
// Tie break by docID
|
||||
return a - b;
|
||||
}
|
||||
} catch (IOException ioe) {
|
||||
throw new RuntimeException(ioe);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
boolean fail = false;
|
||||
for(int i=0;i<numHits;i++) {
|
||||
fail |= expected[i].intValue() != hits2.scoreDocs[i].doc;
|
||||
}
|
||||
assertFalse(fail);
|
||||
|
||||
r.close();
|
||||
dir.close();
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.expressions;
|
|||
import org.apache.lucene.expressions.js.JavascriptCompiler; // javadocs
|
||||
import org.apache.lucene.queries.function.FunctionValues;
|
||||
import org.apache.lucene.queries.function.ValueSource;
|
||||
import org.apache.lucene.search.Rescorer;
|
||||
import org.apache.lucene.search.SortField;
|
||||
|
||||
/**
|
||||
|
@ -83,4 +84,10 @@ public abstract class Expression {
|
|||
public SortField getSortField(Bindings bindings, boolean reverse) {
|
||||
return getValueSource(bindings).getSortField(reverse);
|
||||
}
|
||||
|
||||
/** Get a {@link Rescorer}, to rescore first-pass hits
|
||||
* using this expression. */
|
||||
public Rescorer getRescorer(Bindings bindings) {
|
||||
return new ExpressionRescorer(this, bindings);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,134 @@
|
|||
package org.apache.lucene.expressions;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.apache.lucene.index.AtomicReaderContext;
|
||||
import org.apache.lucene.index.ReaderUtil;
|
||||
import org.apache.lucene.queries.function.ValueSource;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Rescorer;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortRescorer;
|
||||
import org.apache.lucene.search.Weight;
|
||||
|
||||
/**
|
||||
* A {@link Rescorer} that uses an expression to re-score
|
||||
* first pass hits. Functionally this is the same as {@link
|
||||
* SortRescorer} (if you build the {@link Sort} using {@link
|
||||
* Expression#getSortField}), except for the explain method
|
||||
* which gives more detail by showing the value of each
|
||||
* variable.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
|
||||
class ExpressionRescorer extends SortRescorer {
|
||||
|
||||
private final Expression expression;
|
||||
private final Bindings bindings;
|
||||
|
||||
/** Uses the provided {@link ValueSource} to assign second
|
||||
* pass scores. */
|
||||
public ExpressionRescorer(Expression expression, Bindings bindings) {
|
||||
super(new Sort(expression.getSortField(bindings, true)));
|
||||
this.expression = expression;
|
||||
this.bindings = bindings;
|
||||
}
|
||||
|
||||
private static class FakeScorer extends Scorer {
|
||||
float score;
|
||||
int doc = -1;
|
||||
int freq = 1;
|
||||
|
||||
public FakeScorer() {
|
||||
super(null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException("FakeScorer doesn't support advance(int)");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int freq() {
|
||||
return freq;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
throw new UnsupportedOperationException("FakeScorer doesn't support nextDoc()");
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() {
|
||||
return score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight getWeight() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<ChildScorer> getChildren() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException {
|
||||
Explanation result = super.explain(searcher, firstPassExplanation, docID);
|
||||
|
||||
List<AtomicReaderContext> leaves = searcher.getIndexReader().leaves();
|
||||
int subReader = ReaderUtil.subIndex(docID, leaves);
|
||||
AtomicReaderContext readerContext = leaves.get(subReader);
|
||||
int docIDInSegment = docID - readerContext.docBase;
|
||||
Map<String,Object> context = new HashMap<>();
|
||||
|
||||
FakeScorer fakeScorer = new FakeScorer();
|
||||
fakeScorer.score = firstPassExplanation.getValue();
|
||||
fakeScorer.doc = docIDInSegment;
|
||||
|
||||
context.put("scorer", fakeScorer);
|
||||
|
||||
for(String variable : expression.variables) {
|
||||
result.addDetail(new Explanation((float) bindings.getValueSource(variable).getValues(context, readerContext).doubleVal(docIDInSegment),
|
||||
"variable \"" + variable + "\""));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
package org.apache.lucene.expressions;
|
||||
|
||||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.expressions.js.JavascriptCompiler;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.RandomIndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.Rescorer;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TermQuery;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
|
||||
public class TestExpressionRescorer extends LuceneTestCase {
|
||||
IndexSearcher searcher;
|
||||
DirectoryReader reader;
|
||||
Directory dir;
|
||||
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
super.setUp();
|
||||
dir = newDirectory();
|
||||
RandomIndexWriter iw = new RandomIndexWriter(random(), dir);
|
||||
|
||||
Document doc = new Document();
|
||||
doc.add(newStringField("id", "1", Field.Store.YES));
|
||||
doc.add(newTextField("body", "some contents and more contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 5));
|
||||
iw.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(newStringField("id", "2", Field.Store.YES));
|
||||
doc.add(newTextField("body", "another document with different contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 20));
|
||||
iw.addDocument(doc);
|
||||
|
||||
doc = new Document();
|
||||
doc.add(newStringField("id", "3", Field.Store.YES));
|
||||
doc.add(newTextField("body", "crappy contents", Field.Store.NO));
|
||||
doc.add(new NumericDocValuesField("popularity", 2));
|
||||
iw.addDocument(doc);
|
||||
|
||||
reader = iw.getReader();
|
||||
searcher = new IndexSearcher(reader);
|
||||
iw.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void tearDown() throws Exception {
|
||||
reader.close();
|
||||
dir.close();
|
||||
super.tearDown();
|
||||
}
|
||||
|
||||
public void testBasic() throws Exception {
|
||||
|
||||
// create a sort field and sort by it (reverse order)
|
||||
Query query = new TermQuery(new Term("body", "contents"));
|
||||
IndexReader r = searcher.getIndexReader();
|
||||
|
||||
// Just first pass query
|
||||
TopDocs hits = searcher.search(query, 10);
|
||||
assertEquals(3, hits.totalHits);
|
||||
assertEquals("3", r.document(hits.scoreDocs[0].doc).get("id"));
|
||||
assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id"));
|
||||
assertEquals("2", r.document(hits.scoreDocs[2].doc).get("id"));
|
||||
|
||||
// Now, rescore:
|
||||
|
||||
Expression e = JavascriptCompiler.compile("sqrt(_score) + ln(popularity)");
|
||||
SimpleBindings bindings = new SimpleBindings();
|
||||
bindings.add(new SortField("popularity", SortField.Type.INT));
|
||||
bindings.add(new SortField("_score", SortField.Type.SCORE));
|
||||
Rescorer rescorer = e.getRescorer(bindings);
|
||||
|
||||
hits = rescorer.rescore(searcher, hits, 10);
|
||||
assertEquals(3, hits.totalHits);
|
||||
assertEquals("2", r.document(hits.scoreDocs[0].doc).get("id"));
|
||||
assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id"));
|
||||
assertEquals("3", r.document(hits.scoreDocs[2].doc).get("id"));
|
||||
|
||||
String expl = rescorer.explain(searcher,
|
||||
searcher.explain(query, hits.scoreDocs[0].doc),
|
||||
hits.scoreDocs[0].doc).toString();
|
||||
|
||||
// Confirm the explanation breaks out the individual
|
||||
// variables:
|
||||
assertTrue(expl.contains("= variable \"popularity\""));
|
||||
|
||||
// Confirm the explanation includes first pass details:
|
||||
assertTrue(expl.contains("= first pass score"));
|
||||
assertTrue(expl.contains("body:contents in"));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue