Moving Weight implementation to nested class from anonymous

This commit is contained in:
Ankit Jain 2024-08-20 10:50:08 -07:00
parent 4404aa3fe9
commit 0080673c99
1 changed files with 480 additions and 438 deletions

View File

@ -49,11 +49,8 @@ import org.apache.lucene.util.IntsRef;
* @lucene.experimental * @lucene.experimental
*/ */
public abstract class PointRangeQuery extends Query { public abstract class PointRangeQuery extends Query {
final String field;
final int numDims; private PRQConfig config;
final int bytesPerDim;
final byte[] lowerPoint;
final byte[] upperPoint;
/** /**
* Expert: create a multidimensional range query for point values. * Expert: create a multidimensional range query for point values.
@ -66,29 +63,7 @@ public abstract class PointRangeQuery extends Query {
* upperValue.length} * upperValue.length}
*/ */
protected PointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) { protected PointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
checkArgs(field, lowerPoint, upperPoint); config = new PRQConfig(field, lowerPoint, upperPoint, numDims);
this.field = field;
if (numDims <= 0) {
throw new IllegalArgumentException("numDims must be positive, got " + numDims);
}
if (lowerPoint.length == 0) {
throw new IllegalArgumentException("lowerPoint has length of zero");
}
if (lowerPoint.length % numDims != 0) {
throw new IllegalArgumentException("lowerPoint is not a fixed multiple of numDims");
}
if (lowerPoint.length != upperPoint.length) {
throw new IllegalArgumentException(
"lowerPoint has length="
+ lowerPoint.length
+ " but upperPoint has different length="
+ upperPoint.length);
}
this.numDims = numDims;
this.bytesPerDim = lowerPoint.length / numDims;
this.lowerPoint = lowerPoint;
this.upperPoint = upperPoint;
} }
/** /**
@ -111,7 +86,7 @@ public abstract class PointRangeQuery extends Query {
@Override @Override
public void visit(QueryVisitor visitor) { public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) { if (visitor.acceptField(config.field())) {
visitor.visitLeaf(this); visitor.visitLeaf(this);
} }
} }
@ -122,423 +97,32 @@ public abstract class PointRangeQuery extends Query {
// We don't use RandomAccessWeight here: it's no good to approximate with "match all docs". // We don't use RandomAccessWeight here: it's no good to approximate with "match all docs".
// This is an inverted structure and should be used in the first pass: // This is an inverted structure and should be used in the first pass:
return new PRQWeightImpl(config, scoreMode, this, boost);
return new ConstantScoreWeight(this, boost) {
private final ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
private boolean matches(byte[] packedValue) {
int offset = 0;
for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) {
if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) {
// Doc's value is too low, in this dimension
return false;
}
if (comparator.compare(packedValue, offset, upperPoint, offset) > 0) {
// Doc's value is too high, in this dimension
return false;
}
}
return true;
}
private Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
boolean crosses = false;
int offset = 0;
for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) {
if (comparator.compare(minPackedValue, offset, upperPoint, offset) > 0
|| comparator.compare(maxPackedValue, offset, lowerPoint, offset) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
crosses |=
comparator.compare(minPackedValue, offset, lowerPoint, offset) < 0
|| comparator.compare(maxPackedValue, offset, upperPoint, offset) > 0;
}
if (crosses) {
return Relation.CELL_CROSSES_QUERY;
} else {
return Relation.CELL_INSIDE_QUERY;
}
}
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
return new IntersectVisitor() {
DocIdSetBuilder.BulkAdder adder;
@Override
public void grow(int count) {
adder = result.grow(count);
}
@Override
public void visit(int docID) {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
adder.add(iterator);
}
@Override
public void visit(IntsRef ref) {
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
adder.add(ref.ints[i]);
}
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {
visit(docID);
}
}
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
if (matches(packedValue)) {
adder.add(iterator);
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return relate(minPackedValue, maxPackedValue);
}
};
}
/** Create a visitor that clears documents that do NOT match the range. */
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
return new IntersectVisitor() {
@Override
public void visit(int docID) {
result.clear(docID);
cost[0]--;
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
result.andNot(iterator);
cost[0] = Math.max(0, cost[0] - iterator.cost());
}
@Override
public void visit(IntsRef ref) {
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
result.clear(ref.ints[i]);
}
cost[0] -= ref.length;
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue) == false) {
visit(docID);
}
}
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
if (matches(packedValue) == false) {
visit(iterator);
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
Relation relation = relate(minPackedValue, maxPackedValue);
switch (relation) {
case CELL_INSIDE_QUERY:
// all points match, skip this subtree
return Relation.CELL_OUTSIDE_QUERY;
case CELL_OUTSIDE_QUERY:
// none of the points match, clear all documents
return Relation.CELL_INSIDE_QUERY;
case CELL_CROSSES_QUERY:
default:
return relation;
}
}
};
}
private boolean checkValidPointValues(PointValues values) throws IOException {
if (values == null) {
// No docs in this segment/field indexed any points
return false;
}
if (values.getNumIndexDimensions() != numDims) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" was indexed with numIndexDimensions="
+ values.getNumIndexDimensions()
+ " but this query has numDims="
+ numDims);
}
if (bytesPerDim != values.getBytesPerDimension()) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" was indexed with bytesPerDim="
+ values.getBytesPerDimension()
+ " but this query has bytesPerDim="
+ bytesPerDim);
}
return true;
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
PointValues values = reader.getPointValues(field);
if (checkValidPointValues(values) == false) {
return null;
}
if (values.getDocCount() == 0) {
return null;
} else {
final byte[] fieldPackedLower = values.getMinPackedValue();
final byte[] fieldPackedUpper = values.getMaxPackedValue();
for (int i = 0; i < numDims; ++i) {
int offset = i * bytesPerDim;
if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0
|| comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) {
// If this query is a required clause of a boolean query, then returning null here
// will help make sure that we don't call ScorerSupplier#get on other required clauses
// of the same boolean query, which is an expensive operation for some queries (e.g.
// multi-term queries).
return null;
}
}
}
boolean allDocsMatch;
if (values.getDocCount() == reader.maxDoc()) {
final byte[] fieldPackedLower = values.getMinPackedValue();
final byte[] fieldPackedUpper = values.getMaxPackedValue();
allDocsMatch = true;
for (int i = 0; i < numDims; ++i) {
int offset = i * bytesPerDim;
if (comparator.compare(lowerPoint, offset, fieldPackedLower, offset) > 0
|| comparator.compare(upperPoint, offset, fieldPackedUpper, offset) < 0) {
allDocsMatch = false;
break;
}
}
} else {
allDocsMatch = false;
}
if (allDocsMatch) {
// all docs have a value and all points are within bounds, so everything matches
return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) {
return new ConstantScoreScorer(
score(), scoreMode, DocIdSetIterator.all(reader.maxDoc()));
}
@Override
public long cost() {
return reader.maxDoc();
}
};
} else {
return new ScorerSupplier() {
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
final IntersectVisitor visitor = getIntersectVisitor(result);
long cost = -1;
@Override
public Scorer get(long leadCost) throws IOException {
if (values.getDocCount() == reader.maxDoc()
&& values.getDocCount() == values.size()
&& cost() > reader.maxDoc() / 2) {
// If all docs have exactly one value and the cost is greater
// than half the leaf size then maybe we can make things faster
// by computing the set of documents that do NOT match the range
final FixedBitSet result = new FixedBitSet(reader.maxDoc());
result.set(0, reader.maxDoc());
long[] cost = new long[] {reader.maxDoc()};
values.intersect(getInverseIntersectVisitor(result, cost));
final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]);
return new ConstantScoreScorer(score(), scoreMode, iterator);
}
values.intersect(visitor);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(score(), scoreMode, iterator);
}
@Override
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimateDocCount(visitor);
assert cost >= 0;
}
return cost;
}
};
}
}
@Override
public int count(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
PointValues values = reader.getPointValues(field);
if (checkValidPointValues(values) == false) {
return 0;
}
if (reader.hasDeletions() == false) {
if (relate(values.getMinPackedValue(), values.getMaxPackedValue())
== Relation.CELL_INSIDE_QUERY) {
return values.getDocCount();
}
// only 1D: we have the guarantee that it will actually run fast since there are at most 2
// crossing leaves.
// docCount == size : counting according number of points in leaf node, so must be
// single-valued.
if (numDims == 1 && values.getDocCount() == values.size()) {
return (int) pointCount(values.getPointTree(), this::relate, this::matches);
}
}
return super.count(context);
}
/**
* Finds the number of points matching the provided range conditions. Using this method is
* faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of
* intersecting points. This method does not enforce live documents, therefore it should only
* be used when there are no deleted documents.
*
* @param pointTree start node of the count operation
* @param nodeComparator comparator to be used for checking whether the internal node is
* inside the range
* @param leafComparator comparator to be used for checking whether the leaf node is inside
* the range
* @return count of points that match the range
*/
private long pointCount(
PointValues.PointTree pointTree,
BiFunction<byte[], byte[], Relation> nodeComparator,
Predicate<byte[]> leafComparator)
throws IOException {
final long[] matchingNodeCount = {0};
// create a custom IntersectVisitor that records the number of leafNodes that matched
final IntersectVisitor visitor =
new IntersectVisitor() {
@Override
public void visit(int docID) {
// this branch should be unreachable
throw new UnsupportedOperationException(
"This IntersectVisitor does not perform any actions on a "
+ "docID="
+ docID
+ " node being visited");
}
@Override
public void visit(int docID, byte[] packedValue) {
if (leafComparator.test(packedValue)) {
matchingNodeCount[0]++;
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return nodeComparator.apply(minPackedValue, maxPackedValue);
}
};
pointCount(visitor, pointTree, matchingNodeCount);
return matchingNodeCount[0];
}
private void pointCount(
IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount)
throws IOException {
Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
switch (r) {
case CELL_OUTSIDE_QUERY:
// This cell is fully outside the query shape: return 0 as the count of its nodes
return;
case CELL_INSIDE_QUERY:
// This cell is fully inside the query shape: return the size of the entire node as the
// count
matchingNodeCount[0] += pointTree.size();
return;
case CELL_CROSSES_QUERY:
/*
The cell crosses the shape boundary, or the cell fully contains the query, so we fall
through and do full counting.
*/
if (pointTree.moveToChild()) {
do {
pointCount(visitor, pointTree, matchingNodeCount);
} while (pointTree.moveToSibling());
pointTree.moveToParent();
} else {
// we have reached a leaf node here.
pointTree.visitDocValues(visitor);
// leaf node count is saved in the matchingNodeCount array by the visitor
}
return;
default:
throw new IllegalArgumentException("Unreachable code");
}
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
} }
public String getField() { public String getField() {
return field; return config.field();
} }
public int getNumDims() { public int getNumDims() {
return numDims; return config.numDims();
} }
public int getBytesPerDim() { public int getBytesPerDim() {
return bytesPerDim; return config.bytesPerDim();
} }
public byte[] getLowerPoint() { public byte[] getLowerPoint() {
return lowerPoint.clone(); return config.lowerPoint().clone();
} }
public byte[] getUpperPoint() { public byte[] getUpperPoint() {
return upperPoint.clone(); return config.upperPoint().clone();
} }
@Override @Override
public final int hashCode() { public final int hashCode() {
int hash = classHash(); return (31 ^ 5 * classHash() + config.hashCode());
hash = 31 * hash + field.hashCode();
hash = 31 * hash + Arrays.hashCode(lowerPoint);
hash = 31 * hash + Arrays.hashCode(upperPoint);
hash = 31 * hash + numDims;
hash = 31 * hash + Objects.hashCode(bytesPerDim);
return hash;
} }
@Override @Override
@ -547,37 +131,37 @@ public abstract class PointRangeQuery extends Query {
} }
private boolean equalsTo(PointRangeQuery other) { private boolean equalsTo(PointRangeQuery other) {
return Objects.equals(field, other.field) return this.config.equalsTo(other.config);
&& numDims == other.numDims
&& bytesPerDim == other.bytesPerDim
&& Arrays.equals(lowerPoint, other.lowerPoint)
&& Arrays.equals(upperPoint, other.upperPoint);
} }
@Override @Override
public final String toString(String field) { public final String toString(String field) {
final StringBuilder sb = new StringBuilder(); final StringBuilder sb = new StringBuilder();
if (this.field.equals(field) == false) { if (config.field().equals(field) == false) {
sb.append(this.field); sb.append(config.field());
sb.append(':'); sb.append(':');
} }
// print ourselves as "range per dimension" // print ourselves as "range per dimension"
for (int i = 0; i < numDims; i++) { for (int i = 0; i < config.numDims(); i++) {
if (i > 0) { if (i > 0) {
sb.append(','); sb.append(',');
} }
int startOffset = bytesPerDim * i; int startOffset = config.bytesPerDim() * i;
sb.append('['); sb.append('[');
sb.append( sb.append(
toString( toString(
i, ArrayUtil.copyOfSubArray(lowerPoint, startOffset, startOffset + bytesPerDim))); i,
ArrayUtil.copyOfSubArray(
config.lowerPoint(), startOffset, startOffset + config.bytesPerDim())));
sb.append(" TO "); sb.append(" TO ");
sb.append( sb.append(
toString( toString(
i, ArrayUtil.copyOfSubArray(upperPoint, startOffset, startOffset + bytesPerDim))); i,
ArrayUtil.copyOfSubArray(
config.upperPoint(), startOffset, startOffset + config.bytesPerDim())));
sb.append(']'); sb.append(']');
} }
@ -593,4 +177,462 @@ public abstract class PointRangeQuery extends Query {
* @return human readable value for debugging * @return human readable value for debugging
*/ */
protected abstract String toString(int dimension, byte[] value); protected abstract String toString(int dimension, byte[] value);
/**
* Creates config record for PointRangeQuery that can be easily passed to the nested Weight
* Implementation
*
* @param field field name. must not be {@code null}.
* @param lowerPoint lower portion of the range (inclusive).
* @param upperPoint upper portion of the range (inclusive).
* @param numDims number of dimensions.
*/
public record PRQConfig(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
/**
* @throws IllegalArgumentException if {@code field} is null, or if {@code lowerValue.length !=
* upperValue.length}
*/
public PRQConfig {
checkArgs(field, lowerPoint, upperPoint);
if (numDims <= 0) {
throw new IllegalArgumentException("numDims must be positive, got " + numDims);
}
if (lowerPoint.length == 0) {
throw new IllegalArgumentException("lowerPoint has length of zero");
}
if (lowerPoint.length % numDims != 0) {
throw new IllegalArgumentException("lowerPoint is not a fixed multiple of numDims");
}
if (lowerPoint.length != upperPoint.length) {
throw new IllegalArgumentException(
"lowerPoint has length="
+ lowerPoint.length
+ " but upperPoint has different length="
+ upperPoint.length);
}
}
public int bytesPerDim() {
return lowerPoint.length / numDims;
}
@Override
public int hashCode() {
int hash = field.hashCode();
hash = 31 * hash + Arrays.hashCode(lowerPoint);
hash = 31 * hash + Arrays.hashCode(upperPoint);
hash = 31 * hash + numDims;
hash = 31 * hash + Objects.hashCode(bytesPerDim());
return hash;
}
@Override
public boolean equals(Object o) {
return o != null && getClass() == o.getClass() && equalsTo(getClass().cast(o));
}
private boolean equalsTo(PRQConfig other) {
return Objects.equals(field, other.field)
&& numDims == other.numDims
&& bytesPerDim() == other.bytesPerDim()
&& Arrays.equals(lowerPoint, other.lowerPoint)
&& Arrays.equals(upperPoint, other.upperPoint);
}
}
public static class PRQWeightImpl extends ConstantScoreWeight {
private final PRQConfig config;
private final ScoreMode scoreMode;
private final ByteArrayComparator comparator;
public PRQWeightImpl(PRQConfig config, ScoreMode scoreMode, Query query, float boost) {
super(query, boost);
this.config = config;
this.scoreMode = scoreMode;
this.comparator = ArrayUtil.getUnsignedComparator(this.config.bytesPerDim());
}
private boolean matches(byte[] packedValue) {
int offset = 0;
for (int dim = 0; dim < config.numDims(); dim++, offset += config.bytesPerDim()) {
if (comparator.compare(packedValue, offset, config.lowerPoint(), offset) < 0) {
// Doc's value is too low, in this dimension
return false;
}
if (comparator.compare(packedValue, offset, config.upperPoint(), offset) > 0) {
// Doc's value is too high, in this dimension
return false;
}
}
return true;
}
private Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
boolean crosses = false;
int offset = 0;
for (int dim = 0; dim < config.numDims(); dim++, offset += config.bytesPerDim()) {
if (comparator.compare(minPackedValue, offset, config.upperPoint(), offset) > 0
|| comparator.compare(maxPackedValue, offset, config.lowerPoint(), offset) < 0) {
return Relation.CELL_OUTSIDE_QUERY;
}
crosses |=
comparator.compare(minPackedValue, offset, config.lowerPoint(), offset) < 0
|| comparator.compare(maxPackedValue, offset, config.upperPoint(), offset) > 0;
}
if (crosses) {
return Relation.CELL_CROSSES_QUERY;
} else {
return Relation.CELL_INSIDE_QUERY;
}
}
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
return new IntersectVisitor() {
DocIdSetBuilder.BulkAdder adder;
@Override
public void grow(int count) {
adder = result.grow(count);
}
@Override
public void visit(int docID) {
adder.add(docID);
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
adder.add(iterator);
}
@Override
public void visit(IntsRef ref) {
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
adder.add(ref.ints[i]);
}
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue)) {
visit(docID);
}
}
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
if (matches(packedValue)) {
adder.add(iterator);
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return relate(minPackedValue, maxPackedValue);
}
};
}
/** Create a visitor that clears documents that do NOT match the range. */
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
return new IntersectVisitor() {
@Override
public void visit(int docID) {
result.clear(docID);
cost[0]--;
}
@Override
public void visit(DocIdSetIterator iterator) throws IOException {
result.andNot(iterator);
cost[0] = Math.max(0, cost[0] - iterator.cost());
}
@Override
public void visit(IntsRef ref) {
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
result.clear(ref.ints[i]);
}
cost[0] -= ref.length;
}
@Override
public void visit(int docID, byte[] packedValue) {
if (matches(packedValue) == false) {
visit(docID);
}
}
@Override
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
if (matches(packedValue) == false) {
visit(iterator);
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
Relation relation = relate(minPackedValue, maxPackedValue);
switch (relation) {
case CELL_INSIDE_QUERY:
// all points match, skip this subtree
return Relation.CELL_OUTSIDE_QUERY;
case CELL_OUTSIDE_QUERY:
// none of the points match, clear all documents
return Relation.CELL_INSIDE_QUERY;
case CELL_CROSSES_QUERY:
default:
return relation;
}
}
};
}
private boolean checkValidPointValues(PointValues values) throws IOException {
if (values == null) {
// No docs in this segment/field indexed any points
return false;
}
if (values.getNumIndexDimensions() != config.numDims()) {
throw new IllegalArgumentException(
"field=\""
+ config.field()
+ "\" was indexed with numIndexDimensions="
+ values.getNumIndexDimensions()
+ " but this query has numDims="
+ config.numDims());
}
if (config.bytesPerDim() != values.getBytesPerDimension()) {
throw new IllegalArgumentException(
"field=\""
+ config.field()
+ "\" was indexed with bytesPerDim="
+ values.getBytesPerDimension()
+ " but this query has bytesPerDim="
+ config.bytesPerDim());
}
return true;
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
PointValues values = reader.getPointValues(config.field());
if (checkValidPointValues(values) == false) {
return null;
}
if (values.getDocCount() == 0) {
return null;
} else {
final byte[] fieldPackedLower = values.getMinPackedValue();
final byte[] fieldPackedUpper = values.getMaxPackedValue();
for (int i = 0; i < config.numDims(); ++i) {
int offset = i * config.bytesPerDim();
if (comparator.compare(config.lowerPoint(), offset, fieldPackedUpper, offset) > 0
|| comparator.compare(config.upperPoint(), offset, fieldPackedLower, offset) < 0) {
// If this query is a required clause of a boolean query, then returning null here
// will help make sure that we don't call ScorerSupplier#get on other required clauses
// of the same boolean query, which is an expensive operation for some queries (e.g.
// multi-term queries).
return null;
}
}
}
boolean allDocsMatch;
if (values.getDocCount() == reader.maxDoc()) {
final byte[] fieldPackedLower = values.getMinPackedValue();
final byte[] fieldPackedUpper = values.getMaxPackedValue();
allDocsMatch = true;
for (int i = 0; i < config.numDims(); ++i) {
int offset = i * config.bytesPerDim();
if (comparator.compare(config.lowerPoint(), offset, fieldPackedLower, offset) > 0
|| comparator.compare(config.upperPoint(), offset, fieldPackedUpper, offset) < 0) {
allDocsMatch = false;
break;
}
}
} else {
allDocsMatch = false;
}
if (allDocsMatch) {
// all docs have a value and all points are within bounds, so everything matches
return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) {
return new ConstantScoreScorer(
score(), scoreMode, DocIdSetIterator.all(reader.maxDoc()));
}
@Override
public long cost() {
return reader.maxDoc();
}
};
} else {
return new ScorerSupplier() {
final DocIdSetBuilder result =
new DocIdSetBuilder(reader.maxDoc(), values, config.field());
final IntersectVisitor visitor = getIntersectVisitor(result);
long cost = -1;
@Override
public Scorer get(long leadCost) throws IOException {
if (values.getDocCount() == reader.maxDoc()
&& values.getDocCount() == values.size()
&& cost() > reader.maxDoc() / 2) {
// If all docs have exactly one value and the cost is greater
// than half the leaf size then maybe we can make things faster
// by computing the set of documents that do NOT match the range
final FixedBitSet result = new FixedBitSet(reader.maxDoc());
result.set(0, reader.maxDoc());
long[] cost = new long[] {reader.maxDoc()};
values.intersect(getInverseIntersectVisitor(result, cost));
final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]);
return new ConstantScoreScorer(score(), scoreMode, iterator);
}
values.intersect(visitor);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(score(), scoreMode, iterator);
}
@Override
public long cost() {
if (cost == -1) {
// Computing the cost may be expensive, so only do it if necessary
cost = values.estimateDocCount(visitor);
assert cost >= 0;
}
return cost;
}
};
}
}
@Override
public int count(LeafReaderContext context) throws IOException {
LeafReader reader = context.reader();
PointValues values = reader.getPointValues(config.field());
if (checkValidPointValues(values) == false) {
return 0;
}
if (reader.hasDeletions() == false) {
if (relate(values.getMinPackedValue(), values.getMaxPackedValue())
== Relation.CELL_INSIDE_QUERY) {
return values.getDocCount();
}
// only 1D: we have the guarantee that it will actually run fast since there are at most 2
// crossing leaves.
// docCount == size : counting according number of points in leaf node, so must be
// single-valued.
if (config.numDims() == 1 && values.getDocCount() == values.size()) {
return (int) pointCount(values.getPointTree(), this::relate, this::matches);
}
}
return super.count(context);
}
/**
* Finds the number of points matching the provided range conditions. Using this method is
* faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of
* intersecting points. This method does not enforce live documents, therefore it should only be
* used when there are no deleted documents.
*
* @param pointTree start node of the count operation
* @param nodeComparator comparator to be used for checking whether the internal node is inside
* the range
* @param leafComparator comparator to be used for checking whether the leaf node is inside the
* range
* @return count of points that match the range
*/
private long pointCount(
PointValues.PointTree pointTree,
BiFunction<byte[], byte[], Relation> nodeComparator,
Predicate<byte[]> leafComparator)
throws IOException {
final long[] matchingNodeCount = {0};
// create a custom IntersectVisitor that records the number of leafNodes that matched
final IntersectVisitor visitor =
new IntersectVisitor() {
@Override
public void visit(int docID) {
// this branch should be unreachable
throw new UnsupportedOperationException(
"This IntersectVisitor does not perform any actions on a "
+ "docID="
+ docID
+ " node being visited");
}
@Override
public void visit(int docID, byte[] packedValue) {
if (leafComparator.test(packedValue)) {
matchingNodeCount[0]++;
}
}
@Override
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
return nodeComparator.apply(minPackedValue, maxPackedValue);
}
};
pointCount(visitor, pointTree, matchingNodeCount);
return matchingNodeCount[0];
}
private void pointCount(
IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount)
throws IOException {
Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
switch (r) {
case CELL_OUTSIDE_QUERY:
// This cell is fully outside the query shape: return 0 as the count of its nodes
return;
case CELL_INSIDE_QUERY:
// This cell is fully inside the query shape: return the size of the entire node as the
// count
matchingNodeCount[0] += pointTree.size();
return;
case CELL_CROSSES_QUERY:
/*
The cell crosses the shape boundary, or the cell fully contains the query, so we fall
through and do full counting.
*/
if (pointTree.moveToChild()) {
do {
pointCount(visitor, pointTree, matchingNodeCount);
} while (pointTree.moveToSibling());
pointTree.moveToParent();
} else {
// we have reached a leaf node here.
pointTree.visitDocValues(visitor);
// leaf node count is saved in the matchingNodeCount array by the visitor
}
return;
default:
throw new IllegalArgumentException("Unreachable code");
}
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
}
} }