From cbd8cacc50c9ea7e568228e167325829dafcb640 Mon Sep 17 00:00:00 2001 From: Michael McCandless Date: Sun, 23 Mar 2014 11:44:01 +0000 Subject: [PATCH] LUCENE-5545: add SortRescorer and Expression.getRescorer git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1580490 13f79535-47bb-0310-9956-ffa450edef68 --- lucene/CHANGES.txt | 4 + .../apache/lucene/search/QueryRescorer.java | 168 +++++--------- .../apache/lucene/search/SortRescorer.java | 119 ++++++++++ .../lucene/search/TestQueryRescorer.java | 212 ++++++++++++++++++ .../lucene/search/TestSortRescorer.java | 180 +++++++++++++++ .../apache/lucene/expressions/Expression.java | 7 + .../expressions/ExpressionRescorer.java | 134 +++++++++++ .../expressions/TestExpressionRescorer.java | 117 ++++++++++ 8 files changed, 828 insertions(+), 113 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/SortRescorer.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java create mode 100644 lucene/expressions/src/java/org/apache/lucene/expressions/ExpressionRescorer.java create mode 100644 lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d67a0449ed3..9721ec5c4d8 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -127,6 +127,10 @@ New Features first pass search using scores from a more costly second pass search. (Simon Willnauer, Robert Muir, Mike McCandless) +* LUCENE-5545: Add SortRescorer and Expression.getRescorer, to + resort the hits from a first pass search using a Sort or an + Expression. (Simon Willnauer, Robert Muir, Mike McCandless) + API Changes * LUCENE-5454: Add RandomAccessOrds, an optional extension of SortedSetDocValues diff --git a/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java b/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java index b55af993ff5..8403a99fd17 100644 --- a/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java @@ -20,13 +20,9 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; -import java.util.HashMap; -import java.util.Map; +import java.util.List; import org.apache.lucene.index.AtomicReaderContext; -import org.apache.lucene.util.Bits; - -// TODO: we could also have an ExpressionRescorer /** A {@link Rescorer} that uses a provided Query to assign * scores to the first-pass hits. @@ -52,43 +48,65 @@ public abstract class QueryRescorer extends Rescorer { protected abstract float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore); @Override - public TopDocs rescore(IndexSearcher searcher, TopDocs topDocs, int topN) throws IOException { - int[] docIDs = new int[topDocs.scoreDocs.length]; - for(int i=0;i() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + return a.doc - b.doc; + } + }); - TopDocs topDocs2 = searcher.search(query, new OnlyDocIDsFilter(docIDs), topDocs.scoreDocs.length); + List leaves = searcher.getIndexReader().leaves(); - // TODO: we could save small young GC cost here if we - // cloned the incoming ScoreDoc[], sorted that by doc, - // passed that to OnlyDocIDsFilter, sorted 2nd pass - // TopDocs by doc, did a merge sort to combine the - // scores, and finally re-sorted by the combined score, - // but that is sizable added code complexity for minor - // GC savings: - Map newScores = new HashMap(); - for(ScoreDoc sd : topDocs2.scoreDocs) { - newScores.put(sd.doc, sd.score); - } + Weight weight = searcher.createNormalizedWeight(query); - ScoreDoc[] newHits = new ScoreDoc[topDocs.scoreDocs.length]; - for(int i=0;i= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); } - newHits[i] = new ScoreDoc(sd.doc, combinedScore); + + if (readerContext != null) { + // We advanced to another segment: + docBase = readerContext.docBase; + scorer = weight.scorer(readerContext, null); + } + + int targetDoc = docID - docBase; + int actualDoc = scorer.docID(); + if (actualDoc < targetDoc) { + actualDoc = scorer.advance(targetDoc); + } + + if (actualDoc == targetDoc) { + // Query did match this doc: + hit.score = combine(hit.score, true, scorer.score()); + } else { + // Query did not match this doc: + assert actualDoc > targetDoc; + hit.score = combine(hit.score, false, 0.0f); + } + + hitUpto++; } // TODO: we should do a partial sort (of only topN) // instead, but typically the number of hits is // smallish: - Arrays.sort(newHits, + Arrays.sort(hits, new Comparator() { @Override public int compare(ScoreDoc a, ScoreDoc b) { @@ -105,13 +123,13 @@ public abstract class QueryRescorer extends Rescorer { } }); - if (topN < newHits.length) { + if (topN < hits.length) { ScoreDoc[] subset = new ScoreDoc[topN]; - System.arraycopy(newHits, 0, subset, 0, topN); - newHits = subset; + System.arraycopy(hits, 0, subset, 0, topN); + hits = subset; } - return new TopDocs(topDocs.totalHits, newHits, newHits[0].score); + return new TopDocs(firstPassTopDocs.totalHits, hits, hits[0].score); } @Override @@ -159,80 +177,4 @@ public abstract class QueryRescorer extends Rescorer { } }.rescore(searcher, topDocs, topN); } - - /** Filter accepting only the specified docIDs */ - private static class OnlyDocIDsFilter extends Filter { - - private final int[] docIDs; - - /** Sole constructor. */ - public OnlyDocIDsFilter(int[] docIDs) { - this.docIDs = docIDs; - Arrays.sort(docIDs); - } - - @Override - public DocIdSet getDocIdSet(final AtomicReaderContext context, final Bits acceptDocs) throws IOException { - int loc = Arrays.binarySearch(docIDs, context.docBase); - if (loc < 0) { - loc = -loc-1; - } - - final int startLoc = loc; - final int endDoc = context.docBase + context.reader().maxDoc(); - - return new DocIdSet() { - - int pos = startLoc; - - @Override - public DocIdSetIterator iterator() throws IOException { - return new DocIdSetIterator() { - - int docID; - - @Override - public int docID() { - return docID; - } - - @Override - public int nextDoc() { - if (pos == docIDs.length) { - return NO_MORE_DOCS; - } - int docID = docIDs[pos]; - if (docID >= endDoc) { - return NO_MORE_DOCS; - } - pos++; - assert acceptDocs == null || acceptDocs.get(docID-context.docBase); - return docID-context.docBase; - } - - @Override - public long cost() { - // NOTE: not quite right, since this is cost - // across all segments, and we are supposed to - // return cost for just this segment: - return docIDs.length; - } - - @Override - public int advance(int target) { - // TODO: this is a full binary search; we - // could optimize (a bit) by setting lower - // bound to current pos instead: - int loc = Arrays.binarySearch(docIDs, target + context.docBase); - if (loc < 0) { - loc = -loc-1; - } - pos = loc; - return nextDoc(); - } - }; - } - }; - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/SortRescorer.java b/lucene/core/src/java/org/apache/lucene/search/SortRescorer.java new file mode 100644 index 00000000000..1bb21343c40 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SortRescorer.java @@ -0,0 +1,119 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; + +import org.apache.lucene.index.AtomicReaderContext; + +/** + * A {@link Rescorer} that re-sorts according to a provided + * Sort. + */ + +public class SortRescorer extends Rescorer { + + private final Sort sort; + + /** Sole constructor. */ + public SortRescorer(Sort sort) { + this.sort = sort; + } + + @Override + public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException { + + // Copy ScoreDoc[] and sort by ascending docID: + ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone(); + Arrays.sort(hits, + new Comparator() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + return a.doc - b.doc; + } + }); + + List leaves = searcher.getIndexReader().leaves(); + + TopFieldCollector collector = TopFieldCollector.create(sort, topN, true, true, true, false); + + // Now merge sort docIDs from hits, with reader's leaves: + int hitUpto = 0; + int readerUpto = -1; + int endDoc = 0; + int docBase = 0; + + FakeScorer fakeScorer = new FakeScorer(); + + while (hitUpto < hits.length) { + ScoreDoc hit = hits[hitUpto]; + int docID = hit.doc; + AtomicReaderContext readerContext = null; + while (docID >= endDoc) { + readerUpto++; + readerContext = leaves.get(readerUpto); + endDoc = readerContext.docBase + readerContext.reader().maxDoc(); + } + + if (readerContext != null) { + // We advanced to another segment: + collector.setNextReader(readerContext); + collector.setScorer(fakeScorer); + docBase = readerContext.docBase; + } + + fakeScorer.score = hit.score; + fakeScorer.doc = docID - docBase; + + collector.collect(fakeScorer.doc); + + hitUpto++; + } + + return collector.topDocs(); + } + + @Override + public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException { + TopDocs oneHit = new TopDocs(1, new ScoreDoc[] {new ScoreDoc(docID, firstPassExplanation.getValue())}); + TopDocs hits = rescore(searcher, oneHit, 1); + assert hits.totalHits == 1; + + // TODO: if we could ask the Sort to explain itself then + // we wouldn't need the separate ExpressionRescorer... + Explanation result = new Explanation(0.0f, "sort field values for sort=" + sort.toString()); + + // Add first pass: + Explanation first = new Explanation(firstPassExplanation.getValue(), "first pass score"); + first.addDetail(firstPassExplanation); + result.addDetail(first); + + FieldDoc fieldDoc = (FieldDoc) hits.scoreDocs[0]; + + // Add sort values: + SortField[] sortFields = sort.getSort(); + for(int i=0;i() { + @Override + public int compare(Integer a, Integer b) { + try { + int av = idToNum[Integer.parseInt(r.document(a).get("id"))]; + int bv = idToNum[Integer.parseInt(r.document(b).get("id"))]; + if (av < bv) { + return -reverseInt; + } else if (bv < av) { + return reverseInt; + } else { + // Tie break by docID, ascending + return a - b; + } + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + }); + + boolean fail = false; + for(int i=0;i= context.reader().maxDoc()) { + return NO_MORE_DOCS; + } + return docID; + } + + @Override + public int advance(int target) { + docID = target; + return docID; + } + + @Override + public float score() throws IOException { + int num = idToNum[Integer.parseInt(context.reader().document(docID).get("id"))]; + if (reverse) { + //System.out.println("score doc=" + docID + " num=" + num); + return num; + } else { + //System.out.println("score doc=" + docID + " num=" + -num); + return -num; + } + } + }; + } + + @Override + public Explanation explain(AtomicReaderContext context, int doc) throws IOException { + return null; + } + }; + } + + @Override + public void extractTerms(Set terms) { + } + + @Override + public String toString(String field) { + return "FixedScoreQuery " + idToNum.length + " ids; reverse=" + reverse; + } + + @Override + public boolean equals(Object o) { + if ((o instanceof FixedScoreQuery) == false) { + return false; + } + FixedScoreQuery other = (FixedScoreQuery) o; + return Float.floatToIntBits(getBoost()) == Float.floatToIntBits(other.getBoost()) && + reverse == other.reverse && + Arrays.equals(idToNum, other.idToNum); + } + + @Override + public Query clone() { + return new FixedScoreQuery(idToNum, reverse); + } + + @Override + public int hashCode() { + int PRIME = 31; + int hash = super.hashCode(); + if (reverse) { + hash = PRIME * hash + 3623; + } + hash = PRIME * hash + Arrays.hashCode(idToNum); + return hash; + } + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java b/lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java new file mode 100644 index 00000000000..033b37fe6bb --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSortRescorer.java @@ -0,0 +1,180 @@ +package org.apache.lucene.search; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.TestUtil; + +public class TestSortRescorer extends LuceneTestCase { + IndexSearcher searcher; + DirectoryReader reader; + Directory dir; + + @Override + public void setUp() throws Exception { + super.setUp(); + dir = newDirectory(); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + doc.add(newTextField("body", "some contents and more contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 5)); + iw.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "2", Field.Store.YES)); + doc.add(newTextField("body", "another document with different contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 20)); + iw.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "3", Field.Store.YES)); + doc.add(newTextField("body", "crappy contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 2)); + iw.addDocument(doc); + + reader = iw.getReader(); + searcher = new IndexSearcher(reader); + iw.close(); + } + + @Override + public void tearDown() throws Exception { + reader.close(); + dir.close(); + super.tearDown(); + } + + public void testBasic() throws Exception { + + // create a sort field and sort by it (reverse order) + Query query = new TermQuery(new Term("body", "contents")); + IndexReader r = searcher.getIndexReader(); + + // Just first pass query + TopDocs hits = searcher.search(query, 10); + assertEquals(3, hits.totalHits); + assertEquals("3", r.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", r.document(hits.scoreDocs[2].doc).get("id")); + + // Now, rescore: + Sort sort = new Sort(new SortField("popularity", SortField.Type.INT, true)); + Rescorer rescorer = new SortRescorer(sort); + hits = rescorer.rescore(searcher, hits, 10); + assertEquals(3, hits.totalHits); + assertEquals("2", r.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("3", r.document(hits.scoreDocs[2].doc).get("id")); + + String expl = rescorer.explain(searcher, + searcher.explain(query, hits.scoreDocs[0].doc), + hits.scoreDocs[0].doc).toString(); + + // Confirm the explanation breaks out the individual + // sort fields: + assertTrue(expl.contains("= sort field ! value=20")); + + // Confirm the explanation includes first pass details: + assertTrue(expl.contains("= first pass score")); + assertTrue(expl.contains("body:contents in")); + } + + public void testRandom() throws Exception { + Directory dir = newDirectory(); + int numDocs = atLeast(1000); + RandomIndexWriter w = new RandomIndexWriter(random(), dir); + + final int[] idToNum = new int[numDocs]; + int maxValue = TestUtil.nextInt(random(), 10, 1000000); + for(int i=0;i() { + @Override + public int compare(Integer a, Integer b) { + try { + int av = idToNum[Integer.parseInt(r.document(a).get("id"))]; + int bv = idToNum[Integer.parseInt(r.document(b).get("id"))]; + if (av < bv) { + return -reverseInt; + } else if (bv < av) { + return reverseInt; + } else { + // Tie break by docID + return a - b; + } + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + } + }); + + boolean fail = false; + for(int i=0;i getChildren() { + throw new UnsupportedOperationException(); + } + } + + @Override + public Explanation explain(IndexSearcher searcher, Explanation firstPassExplanation, int docID) throws IOException { + Explanation result = super.explain(searcher, firstPassExplanation, docID); + + List leaves = searcher.getIndexReader().leaves(); + int subReader = ReaderUtil.subIndex(docID, leaves); + AtomicReaderContext readerContext = leaves.get(subReader); + int docIDInSegment = docID - readerContext.docBase; + Map context = new HashMap<>(); + + FakeScorer fakeScorer = new FakeScorer(); + fakeScorer.score = firstPassExplanation.getValue(); + fakeScorer.doc = docIDInSegment; + + context.put("scorer", fakeScorer); + + for(String variable : expression.variables) { + result.addDetail(new Explanation((float) bindings.getValueSource(variable).getValues(context, readerContext).doubleVal(docIDInSegment), + "variable \"" + variable + "\"")); + } + + return result; + } +} diff --git a/lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java b/lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java new file mode 100644 index 00000000000..828f60f1cad --- /dev/null +++ b/lucene/expressions/src/test/org/apache/lucene/expressions/TestExpressionRescorer.java @@ -0,0 +1,117 @@ +package org.apache.lucene.expressions; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.expressions.js.JavascriptCompiler; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Rescorer; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.LuceneTestCase; + +public class TestExpressionRescorer extends LuceneTestCase { + IndexSearcher searcher; + DirectoryReader reader; + Directory dir; + + @Override + public void setUp() throws Exception { + super.setUp(); + dir = newDirectory(); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir); + + Document doc = new Document(); + doc.add(newStringField("id", "1", Field.Store.YES)); + doc.add(newTextField("body", "some contents and more contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 5)); + iw.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "2", Field.Store.YES)); + doc.add(newTextField("body", "another document with different contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 20)); + iw.addDocument(doc); + + doc = new Document(); + doc.add(newStringField("id", "3", Field.Store.YES)); + doc.add(newTextField("body", "crappy contents", Field.Store.NO)); + doc.add(new NumericDocValuesField("popularity", 2)); + iw.addDocument(doc); + + reader = iw.getReader(); + searcher = new IndexSearcher(reader); + iw.close(); + } + + @Override + public void tearDown() throws Exception { + reader.close(); + dir.close(); + super.tearDown(); + } + + public void testBasic() throws Exception { + + // create a sort field and sort by it (reverse order) + Query query = new TermQuery(new Term("body", "contents")); + IndexReader r = searcher.getIndexReader(); + + // Just first pass query + TopDocs hits = searcher.search(query, 10); + assertEquals(3, hits.totalHits); + assertEquals("3", r.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("2", r.document(hits.scoreDocs[2].doc).get("id")); + + // Now, rescore: + + Expression e = JavascriptCompiler.compile("sqrt(_score) + ln(popularity)"); + SimpleBindings bindings = new SimpleBindings(); + bindings.add(new SortField("popularity", SortField.Type.INT)); + bindings.add(new SortField("_score", SortField.Type.SCORE)); + Rescorer rescorer = e.getRescorer(bindings); + + hits = rescorer.rescore(searcher, hits, 10); + assertEquals(3, hits.totalHits); + assertEquals("2", r.document(hits.scoreDocs[0].doc).get("id")); + assertEquals("1", r.document(hits.scoreDocs[1].doc).get("id")); + assertEquals("3", r.document(hits.scoreDocs[2].doc).get("id")); + + String expl = rescorer.explain(searcher, + searcher.explain(query, hits.scoreDocs[0].doc), + hits.scoreDocs[0].doc).toString(); + + // Confirm the explanation breaks out the individual + // variables: + assertTrue(expl.contains("= variable \"popularity\"")); + + // Confirm the explanation includes first pass details: + assertTrue(expl.contains("= first pass score")); + assertTrue(expl.contains("body:contents in")); + } +}