LUCENE-7461: Refactor doc values queries to leverage the new iterator API.

This commit is contained in:
Adrien Grand 2016-11-15 14:55:46 +01:00
parent 0325722e67
commit 212b1d8462
7 changed files with 105 additions and 213 deletions

View File

@ -25,7 +25,6 @@ import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.LongBitSet;
/**
@ -74,9 +73,9 @@ public final class DocValuesRewriteMethod extends MultiTermQuery.RewriteMethod {
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new RandomAccessWeight(this, boost) {
return new ConstantScoreWeight(this, boost) {
@Override
protected Bits getMatchingDocs(LeafReaderContext context) throws IOException {
public Scorer scorer(LeafReaderContext context) throws IOException {
final SortedSetDocValues fcsi = DocValues.getSortedSet(context.reader(), query.field);
TermsEnum termsEnum = query.getTermsEnum(new Terms() {
@ -141,38 +140,28 @@ public final class DocValuesRewriteMethod extends MultiTermQuery.RewriteMethod {
}
} while (termsEnum.next() != null);
return new Bits() {
return new ConstantScoreScorer(this, score(), new TwoPhaseIterator(fcsi) {
@Override
public boolean get(int doc) {
try {
if (doc > fcsi.docID()) {
fcsi.advance(doc);
public boolean matches() throws IOException {
for (long ord = fcsi.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = fcsi.nextOrd()) {
if (termSet.get(ord)) {
return true;
}
if (doc == fcsi.docID()) {
for (long ord = fcsi.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = fcsi.nextOrd()) {
if (termSet.get(ord)) {
return true;
}
}
}
return false;
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return false;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 3; // lookup in a bitset
}
};
});
}
};
}
}
@Override
public boolean equals(Object other) {
return other != null &&

View File

@ -1,76 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.Bits.MatchNoBits;
/**
* Base class to build {@link Weight}s that are based on random-access
* structures such as live docs or doc values. Such weights return a
* {@link Scorer} which consists of an approximation that matches
* everything, and a confirmation phase that first checks live docs and
* then the {@link Bits} returned by {@link #getMatchingDocs(LeafReaderContext)}.
* @lucene.internal
*/
public abstract class RandomAccessWeight extends ConstantScoreWeight {
/** Sole constructor. */
protected RandomAccessWeight(Query query, float boost) {
super(query, boost);
}
/**
* Return a {@link Bits} instance representing documents that match this
* weight on the given context. A return value of {@code null} indicates
* that no documents matched.
* Note: it is not needed to care about live docs as they will be checked
* before the returned bits.
*/
protected abstract Bits getMatchingDocs(LeafReaderContext context) throws IOException;
@Override
public final Scorer scorer(LeafReaderContext context) throws IOException {
final Bits matchingDocs = getMatchingDocs(context);
if (matchingDocs == null || matchingDocs instanceof MatchNoBits) {
return null;
}
final DocIdSetIterator approximation = DocIdSetIterator.all(context.reader().maxDoc());
final TwoPhaseIterator twoPhase = new TwoPhaseIterator(approximation) {
@Override
public boolean matches() throws IOException {
final int doc = approximation.docID();
return matchingDocs.get(doc);
}
@Override
public float matchCost() {
return 10; // TODO: use some cost of matchingDocs
}
};
return new ConstantScoreScorer(this, score(), twoPhase);
}
}

View File

@ -46,19 +46,22 @@ import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.RandomAccessWeight;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.SimpleCollector;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InPlaceMergeSorter;
@ -651,27 +654,26 @@ public class TestDrillSideways extends FacetTestCase {
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new RandomAccessWeight(this, boost) {
return new ConstantScoreWeight(this, boost) {
@Override
protected Bits getMatchingDocs(final LeafReaderContext context) throws IOException {
return new Bits() {
public Scorer scorer(LeafReaderContext context) throws IOException {
DocIdSetIterator approximation = DocIdSetIterator.all(context.reader().maxDoc());
return new ConstantScoreScorer(this, score(), new TwoPhaseIterator(approximation) {
@Override
public boolean get(int docID) {
try {
return (Integer.parseInt(context.reader().document(docID).get("id")) & 1) == 0;
} catch (NumberFormatException | IOException e) {
throw new RuntimeException(e);
}
public boolean matches() throws IOException {
int docID = approximation.docID();
return (Integer.parseInt(context.reader().document(docID).get("id")) & 1) == 0;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 1000f;
}
};
});
}
};
}

View File

@ -27,7 +27,6 @@ import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.Bits;
/**
* Like {@link DocValuesTermsQuery}, but this query only
@ -96,38 +95,29 @@ public class DocValuesNumbersQuery extends Query {
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new RandomAccessWeight(this, boost) {
return new ConstantScoreWeight(this, boost) {
@Override
protected Bits getMatchingDocs(LeafReaderContext context) throws IOException {
final SortedNumericDocValues values = DocValues.getSortedNumeric(context.reader(), field);
return new Bits() {
public Scorer scorer(LeafReaderContext context) throws IOException {
final SortedNumericDocValues values = DocValues.getSortedNumeric(context.reader(), field);
return new ConstantScoreScorer(this, score(), new TwoPhaseIterator(values) {
@Override
public boolean get(int doc) {
try {
if (doc > values.docID()) {
values.advance(doc);
}
if (doc == values.docID()) {
int count = values.docValueCount();
for(int i=0;i<count;i++) {
if (numbers.contains(values.nextValue())) {
return true;
}
}
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return false;
@Override
public boolean matches() throws IOException {
int count = values.docValueCount();
for(int i=0;i<count;i++) {
if (numbers.contains(values.nextValue())) {
return true;
}
}
return false;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 5; // lookup in the set
}
};
});
}
};
}

View File

@ -25,7 +25,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
/**
@ -143,10 +142,19 @@ public final class DocValuesRangeQuery extends Query {
if (lowerVal == null && upperVal == null) {
throw new IllegalStateException("Both min and max values must not be null, call rewrite first");
}
return new RandomAccessWeight(DocValuesRangeQuery.this, boost) {
return new ConstantScoreWeight(DocValuesRangeQuery.this, boost) {
@Override
protected Bits getMatchingDocs(LeafReaderContext context) throws IOException {
public Scorer scorer(LeafReaderContext context) throws IOException {
final TwoPhaseIterator iterator = createTwoPhaseIterator(context);
if (iterator == null) {
return null;
}
return new ConstantScoreScorer(this, score(), iterator);
}
private TwoPhaseIterator createTwoPhaseIterator(LeafReaderContext context) throws IOException {
if (lowerVal instanceof Long || upperVal instanceof Long) {
final SortedNumericDocValues values = DocValues.getSortedNumeric(context.reader(), field);
@ -179,32 +187,24 @@ public final class DocValuesRangeQuery extends Query {
return null;
}
return new Bits() {
return new TwoPhaseIterator(values) {
@Override
public boolean get(int doc) {
try {
if (doc > values.docID()) {
values.advance(doc);
public boolean matches() throws IOException {
final int count = values.docValueCount();
assert count > 0;
for (int i = 0; i < count; ++i) {
final long value = values.nextValue();
if (value >= min && value <= max) {
return true;
}
if (doc == values.docID()) {
final int count = values.docValueCount();
for (int i = 0; i < count; ++i) {
final long value = values.nextValue();
if (value >= min && value <= max) {
return true;
}
}
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return false;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 2; // 2 comparisons
}
};
@ -245,32 +245,22 @@ public final class DocValuesRangeQuery extends Query {
return null;
}
return new Bits() {
return new TwoPhaseIterator(values) {
@Override
public boolean get(int doc) {
try {
if (doc > values.docID()) {
values.advance(doc);
public boolean matches() throws IOException {
for (long ord = values.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = values.nextOrd()) {
if (ord >= minOrd && ord <= maxOrd) {
return true;
}
if (doc == values.docID()) {
for (long ord = values.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = values.nextOrd()) {
if (ord >= minOrd && ord <= maxOrd) {
return true;
}
}
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return false;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 2; // 2 comparisons
}
};
} else {

View File

@ -27,7 +27,6 @@ import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.LongBitSet;
@ -149,45 +148,41 @@ public class DocValuesTermsQuery extends Query {
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new RandomAccessWeight(this, boost) {
return new ConstantScoreWeight(this, boost) {
@Override
protected Bits getMatchingDocs(LeafReaderContext context) throws IOException {
public Scorer scorer(LeafReaderContext context) throws IOException {
final SortedSetDocValues values = DocValues.getSortedSet(context.reader(), field);
final LongBitSet bits = new LongBitSet(values.getValueCount());
boolean matchesAtLeastOneTerm = false;
for (BytesRef term : terms) {
final long ord = values.lookupTerm(term);
if (ord >= 0) {
matchesAtLeastOneTerm = true;
bits.set(ord);
}
}
return new Bits() {
if (matchesAtLeastOneTerm == false) {
return null;
}
return new ConstantScoreScorer(this, score(), new TwoPhaseIterator(values) {
@Override
public boolean get(int doc) {
try {
if (doc > values.docID()) {
values.advance(doc);
public boolean matches() throws IOException {
for (long ord = values.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = values.nextOrd()) {
if (bits.get(ord)) {
return true;
}
if (doc == values.docID()) {
for (long ord = values.nextOrd(); ord != SortedSetDocValues.NO_MORE_ORDS; ord = values.nextOrd()) {
if (bits.get(ord)) {
return true;
}
}
}
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
return false;
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
return 3; // lookup in a bitset
}
};
});
}
};
}

View File

@ -30,17 +30,19 @@ import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.RandomAccessWeight;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.spatial.SpatialStrategy;
import org.apache.lucene.spatial.query.SpatialArgs;
import org.apache.lucene.spatial.util.DistanceToShapeValueSource;
import org.apache.lucene.spatial.util.ShapePredicateValueSource;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.locationtech.spatial4j.context.SpatialContext;
@ -136,25 +138,25 @@ public class SerializedDVStrategy extends SpatialStrategy {
@Override
public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
return new RandomAccessWeight(this, boost) {
return new ConstantScoreWeight(this, boost) {
@Override
protected Bits getMatchingDocs(LeafReaderContext context) throws IOException {
public Scorer scorer(LeafReaderContext context) throws IOException {
DocIdSetIterator approximation = DocIdSetIterator.all(context.reader().maxDoc());
final FunctionValues predFuncValues = predicateValueSource.getValues(null, context);
return new Bits() {
return new ConstantScoreScorer(this, score(), new TwoPhaseIterator(approximation) {
@Override
public boolean get(int index) {
try {
return predFuncValues.boolVal(index);
} catch (IOException ioe) {
throw new RuntimeException(ioe);
}
public boolean matches() throws IOException {
final int docID = approximation.docID();
return predFuncValues.boolVal(docID);
}
@Override
public int length() {
return context.reader().maxDoc();
public float matchCost() {
// TODO: what is the cost of the predicateValueSource
return 100f;
}
};
});
}
};
}