mirror of https://github.com/apache/lucene.git
Moving Weight implementation to nested class from anonymous
This commit is contained in:
parent
4404aa3fe9
commit
0080673c99
|
@ -49,11 +49,8 @@ import org.apache.lucene.util.IntsRef;
|
|||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class PointRangeQuery extends Query {
|
||||
final String field;
|
||||
final int numDims;
|
||||
final int bytesPerDim;
|
||||
final byte[] lowerPoint;
|
||||
final byte[] upperPoint;
|
||||
|
||||
private PRQConfig config;
|
||||
|
||||
/**
|
||||
* Expert: create a multidimensional range query for point values.
|
||||
|
@ -66,29 +63,7 @@ public abstract class PointRangeQuery extends Query {
|
|||
* upperValue.length}
|
||||
*/
|
||||
protected PointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
|
||||
checkArgs(field, lowerPoint, upperPoint);
|
||||
this.field = field;
|
||||
if (numDims <= 0) {
|
||||
throw new IllegalArgumentException("numDims must be positive, got " + numDims);
|
||||
}
|
||||
if (lowerPoint.length == 0) {
|
||||
throw new IllegalArgumentException("lowerPoint has length of zero");
|
||||
}
|
||||
if (lowerPoint.length % numDims != 0) {
|
||||
throw new IllegalArgumentException("lowerPoint is not a fixed multiple of numDims");
|
||||
}
|
||||
if (lowerPoint.length != upperPoint.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"lowerPoint has length="
|
||||
+ lowerPoint.length
|
||||
+ " but upperPoint has different length="
|
||||
+ upperPoint.length);
|
||||
}
|
||||
this.numDims = numDims;
|
||||
this.bytesPerDim = lowerPoint.length / numDims;
|
||||
|
||||
this.lowerPoint = lowerPoint;
|
||||
this.upperPoint = upperPoint;
|
||||
config = new PRQConfig(field, lowerPoint, upperPoint, numDims);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -111,7 +86,7 @@ public abstract class PointRangeQuery extends Query {
|
|||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
if (visitor.acceptField(field)) {
|
||||
if (visitor.acceptField(config.field())) {
|
||||
visitor.visitLeaf(this);
|
||||
}
|
||||
}
|
||||
|
@ -122,423 +97,32 @@ public abstract class PointRangeQuery extends Query {
|
|||
|
||||
// We don't use RandomAccessWeight here: it's no good to approximate with "match all docs".
|
||||
// This is an inverted structure and should be used in the first pass:
|
||||
|
||||
return new ConstantScoreWeight(this, boost) {
|
||||
|
||||
private final ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim);
|
||||
|
||||
private boolean matches(byte[] packedValue) {
|
||||
int offset = 0;
|
||||
for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) {
|
||||
if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) {
|
||||
// Doc's value is too low, in this dimension
|
||||
return false;
|
||||
}
|
||||
if (comparator.compare(packedValue, offset, upperPoint, offset) > 0) {
|
||||
// Doc's value is too high, in this dimension
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
|
||||
boolean crosses = false;
|
||||
int offset = 0;
|
||||
|
||||
for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) {
|
||||
|
||||
if (comparator.compare(minPackedValue, offset, upperPoint, offset) > 0
|
||||
|| comparator.compare(maxPackedValue, offset, lowerPoint, offset) < 0) {
|
||||
return Relation.CELL_OUTSIDE_QUERY;
|
||||
}
|
||||
|
||||
crosses |=
|
||||
comparator.compare(minPackedValue, offset, lowerPoint, offset) < 0
|
||||
|| comparator.compare(maxPackedValue, offset, upperPoint, offset) > 0;
|
||||
}
|
||||
|
||||
if (crosses) {
|
||||
return Relation.CELL_CROSSES_QUERY;
|
||||
} else {
|
||||
return Relation.CELL_INSIDE_QUERY;
|
||||
}
|
||||
}
|
||||
|
||||
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
|
||||
return new IntersectVisitor() {
|
||||
|
||||
DocIdSetBuilder.BulkAdder adder;
|
||||
|
||||
@Override
|
||||
public void grow(int count) {
|
||||
adder = result.grow(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
adder.add(docID);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator) throws IOException {
|
||||
adder.add(iterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(IntsRef ref) {
|
||||
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
|
||||
adder.add(ref.ints[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (matches(packedValue)) {
|
||||
visit(docID);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
|
||||
if (matches(packedValue)) {
|
||||
adder.add(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
return relate(minPackedValue, maxPackedValue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/** Create a visitor that clears documents that do NOT match the range. */
|
||||
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
|
||||
return new IntersectVisitor() {
|
||||
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
result.clear(docID);
|
||||
cost[0]--;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator) throws IOException {
|
||||
result.andNot(iterator);
|
||||
cost[0] = Math.max(0, cost[0] - iterator.cost());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(IntsRef ref) {
|
||||
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
|
||||
result.clear(ref.ints[i]);
|
||||
}
|
||||
cost[0] -= ref.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (matches(packedValue) == false) {
|
||||
visit(docID);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
|
||||
if (matches(packedValue) == false) {
|
||||
visit(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
Relation relation = relate(minPackedValue, maxPackedValue);
|
||||
switch (relation) {
|
||||
case CELL_INSIDE_QUERY:
|
||||
// all points match, skip this subtree
|
||||
return Relation.CELL_OUTSIDE_QUERY;
|
||||
case CELL_OUTSIDE_QUERY:
|
||||
// none of the points match, clear all documents
|
||||
return Relation.CELL_INSIDE_QUERY;
|
||||
case CELL_CROSSES_QUERY:
|
||||
default:
|
||||
return relation;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private boolean checkValidPointValues(PointValues values) throws IOException {
|
||||
if (values == null) {
|
||||
// No docs in this segment/field indexed any points
|
||||
return false;
|
||||
}
|
||||
|
||||
if (values.getNumIndexDimensions() != numDims) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" was indexed with numIndexDimensions="
|
||||
+ values.getNumIndexDimensions()
|
||||
+ " but this query has numDims="
|
||||
+ numDims);
|
||||
}
|
||||
if (bytesPerDim != values.getBytesPerDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" was indexed with bytesPerDim="
|
||||
+ values.getBytesPerDimension()
|
||||
+ " but this query has bytesPerDim="
|
||||
+ bytesPerDim);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
LeafReader reader = context.reader();
|
||||
|
||||
PointValues values = reader.getPointValues(field);
|
||||
if (checkValidPointValues(values) == false) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (values.getDocCount() == 0) {
|
||||
return null;
|
||||
} else {
|
||||
final byte[] fieldPackedLower = values.getMinPackedValue();
|
||||
final byte[] fieldPackedUpper = values.getMaxPackedValue();
|
||||
for (int i = 0; i < numDims; ++i) {
|
||||
int offset = i * bytesPerDim;
|
||||
if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0
|
||||
|| comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) {
|
||||
// If this query is a required clause of a boolean query, then returning null here
|
||||
// will help make sure that we don't call ScorerSupplier#get on other required clauses
|
||||
// of the same boolean query, which is an expensive operation for some queries (e.g.
|
||||
// multi-term queries).
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
boolean allDocsMatch;
|
||||
if (values.getDocCount() == reader.maxDoc()) {
|
||||
final byte[] fieldPackedLower = values.getMinPackedValue();
|
||||
final byte[] fieldPackedUpper = values.getMaxPackedValue();
|
||||
allDocsMatch = true;
|
||||
for (int i = 0; i < numDims; ++i) {
|
||||
int offset = i * bytesPerDim;
|
||||
if (comparator.compare(lowerPoint, offset, fieldPackedLower, offset) > 0
|
||||
|| comparator.compare(upperPoint, offset, fieldPackedUpper, offset) < 0) {
|
||||
allDocsMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allDocsMatch = false;
|
||||
}
|
||||
|
||||
if (allDocsMatch) {
|
||||
// all docs have a value and all points are within bounds, so everything matches
|
||||
return new ScorerSupplier() {
|
||||
@Override
|
||||
public Scorer get(long leadCost) {
|
||||
return new ConstantScoreScorer(
|
||||
score(), scoreMode, DocIdSetIterator.all(reader.maxDoc()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return reader.maxDoc();
|
||||
}
|
||||
};
|
||||
} else {
|
||||
return new ScorerSupplier() {
|
||||
|
||||
final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
|
||||
final IntersectVisitor visitor = getIntersectVisitor(result);
|
||||
long cost = -1;
|
||||
|
||||
@Override
|
||||
public Scorer get(long leadCost) throws IOException {
|
||||
if (values.getDocCount() == reader.maxDoc()
|
||||
&& values.getDocCount() == values.size()
|
||||
&& cost() > reader.maxDoc() / 2) {
|
||||
// If all docs have exactly one value and the cost is greater
|
||||
// than half the leaf size then maybe we can make things faster
|
||||
// by computing the set of documents that do NOT match the range
|
||||
final FixedBitSet result = new FixedBitSet(reader.maxDoc());
|
||||
result.set(0, reader.maxDoc());
|
||||
long[] cost = new long[] {reader.maxDoc()};
|
||||
values.intersect(getInverseIntersectVisitor(result, cost));
|
||||
final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]);
|
||||
return new ConstantScoreScorer(score(), scoreMode, iterator);
|
||||
}
|
||||
|
||||
values.intersect(visitor);
|
||||
DocIdSetIterator iterator = result.build().iterator();
|
||||
return new ConstantScoreScorer(score(), scoreMode, iterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
if (cost == -1) {
|
||||
// Computing the cost may be expensive, so only do it if necessary
|
||||
cost = values.estimateDocCount(visitor);
|
||||
assert cost >= 0;
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int count(LeafReaderContext context) throws IOException {
|
||||
LeafReader reader = context.reader();
|
||||
|
||||
PointValues values = reader.getPointValues(field);
|
||||
if (checkValidPointValues(values) == false) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (reader.hasDeletions() == false) {
|
||||
if (relate(values.getMinPackedValue(), values.getMaxPackedValue())
|
||||
== Relation.CELL_INSIDE_QUERY) {
|
||||
return values.getDocCount();
|
||||
}
|
||||
// only 1D: we have the guarantee that it will actually run fast since there are at most 2
|
||||
// crossing leaves.
|
||||
// docCount == size : counting according number of points in leaf node, so must be
|
||||
// single-valued.
|
||||
if (numDims == 1 && values.getDocCount() == values.size()) {
|
||||
return (int) pointCount(values.getPointTree(), this::relate, this::matches);
|
||||
}
|
||||
}
|
||||
return super.count(context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the number of points matching the provided range conditions. Using this method is
|
||||
* faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of
|
||||
* intersecting points. This method does not enforce live documents, therefore it should only
|
||||
* be used when there are no deleted documents.
|
||||
*
|
||||
* @param pointTree start node of the count operation
|
||||
* @param nodeComparator comparator to be used for checking whether the internal node is
|
||||
* inside the range
|
||||
* @param leafComparator comparator to be used for checking whether the leaf node is inside
|
||||
* the range
|
||||
* @return count of points that match the range
|
||||
*/
|
||||
private long pointCount(
|
||||
PointValues.PointTree pointTree,
|
||||
BiFunction<byte[], byte[], Relation> nodeComparator,
|
||||
Predicate<byte[]> leafComparator)
|
||||
throws IOException {
|
||||
final long[] matchingNodeCount = {0};
|
||||
// create a custom IntersectVisitor that records the number of leafNodes that matched
|
||||
final IntersectVisitor visitor =
|
||||
new IntersectVisitor() {
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
// this branch should be unreachable
|
||||
throw new UnsupportedOperationException(
|
||||
"This IntersectVisitor does not perform any actions on a "
|
||||
+ "docID="
|
||||
+ docID
|
||||
+ " node being visited");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (leafComparator.test(packedValue)) {
|
||||
matchingNodeCount[0]++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
return nodeComparator.apply(minPackedValue, maxPackedValue);
|
||||
}
|
||||
};
|
||||
pointCount(visitor, pointTree, matchingNodeCount);
|
||||
return matchingNodeCount[0];
|
||||
}
|
||||
|
||||
private void pointCount(
|
||||
IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount)
|
||||
throws IOException {
|
||||
Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
|
||||
switch (r) {
|
||||
case CELL_OUTSIDE_QUERY:
|
||||
// This cell is fully outside the query shape: return 0 as the count of its nodes
|
||||
return;
|
||||
case CELL_INSIDE_QUERY:
|
||||
// This cell is fully inside the query shape: return the size of the entire node as the
|
||||
// count
|
||||
matchingNodeCount[0] += pointTree.size();
|
||||
return;
|
||||
case CELL_CROSSES_QUERY:
|
||||
/*
|
||||
The cell crosses the shape boundary, or the cell fully contains the query, so we fall
|
||||
through and do full counting.
|
||||
*/
|
||||
if (pointTree.moveToChild()) {
|
||||
do {
|
||||
pointCount(visitor, pointTree, matchingNodeCount);
|
||||
} while (pointTree.moveToSibling());
|
||||
pointTree.moveToParent();
|
||||
} else {
|
||||
// we have reached a leaf node here.
|
||||
pointTree.visitDocValues(visitor);
|
||||
// leaf node count is saved in the matchingNodeCount array by the visitor
|
||||
}
|
||||
return;
|
||||
default:
|
||||
throw new IllegalArgumentException("Unreachable code");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
return new PRQWeightImpl(config, scoreMode, this, boost);
|
||||
}
|
||||
|
||||
public String getField() {
|
||||
return field;
|
||||
return config.field();
|
||||
}
|
||||
|
||||
public int getNumDims() {
|
||||
return numDims;
|
||||
return config.numDims();
|
||||
}
|
||||
|
||||
public int getBytesPerDim() {
|
||||
return bytesPerDim;
|
||||
return config.bytesPerDim();
|
||||
}
|
||||
|
||||
public byte[] getLowerPoint() {
|
||||
return lowerPoint.clone();
|
||||
return config.lowerPoint().clone();
|
||||
}
|
||||
|
||||
public byte[] getUpperPoint() {
|
||||
return upperPoint.clone();
|
||||
return config.upperPoint().clone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final int hashCode() {
|
||||
int hash = classHash();
|
||||
hash = 31 * hash + field.hashCode();
|
||||
hash = 31 * hash + Arrays.hashCode(lowerPoint);
|
||||
hash = 31 * hash + Arrays.hashCode(upperPoint);
|
||||
hash = 31 * hash + numDims;
|
||||
hash = 31 * hash + Objects.hashCode(bytesPerDim);
|
||||
return hash;
|
||||
return (31 ^ 5 * classHash() + config.hashCode());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -547,37 +131,37 @@ public abstract class PointRangeQuery extends Query {
|
|||
}
|
||||
|
||||
private boolean equalsTo(PointRangeQuery other) {
|
||||
return Objects.equals(field, other.field)
|
||||
&& numDims == other.numDims
|
||||
&& bytesPerDim == other.bytesPerDim
|
||||
&& Arrays.equals(lowerPoint, other.lowerPoint)
|
||||
&& Arrays.equals(upperPoint, other.upperPoint);
|
||||
return this.config.equalsTo(other.config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final String toString(String field) {
|
||||
final StringBuilder sb = new StringBuilder();
|
||||
if (this.field.equals(field) == false) {
|
||||
sb.append(this.field);
|
||||
if (config.field().equals(field) == false) {
|
||||
sb.append(config.field());
|
||||
sb.append(':');
|
||||
}
|
||||
|
||||
// print ourselves as "range per dimension"
|
||||
for (int i = 0; i < numDims; i++) {
|
||||
for (int i = 0; i < config.numDims(); i++) {
|
||||
if (i > 0) {
|
||||
sb.append(',');
|
||||
}
|
||||
|
||||
int startOffset = bytesPerDim * i;
|
||||
int startOffset = config.bytesPerDim() * i;
|
||||
|
||||
sb.append('[');
|
||||
sb.append(
|
||||
toString(
|
||||
i, ArrayUtil.copyOfSubArray(lowerPoint, startOffset, startOffset + bytesPerDim)));
|
||||
i,
|
||||
ArrayUtil.copyOfSubArray(
|
||||
config.lowerPoint(), startOffset, startOffset + config.bytesPerDim())));
|
||||
sb.append(" TO ");
|
||||
sb.append(
|
||||
toString(
|
||||
i, ArrayUtil.copyOfSubArray(upperPoint, startOffset, startOffset + bytesPerDim)));
|
||||
i,
|
||||
ArrayUtil.copyOfSubArray(
|
||||
config.upperPoint(), startOffset, startOffset + config.bytesPerDim())));
|
||||
sb.append(']');
|
||||
}
|
||||
|
||||
|
@ -593,4 +177,462 @@ public abstract class PointRangeQuery extends Query {
|
|||
* @return human readable value for debugging
|
||||
*/
|
||||
protected abstract String toString(int dimension, byte[] value);
|
||||
|
||||
/**
|
||||
* Creates config record for PointRangeQuery that can be easily passed to the nested Weight
|
||||
* Implementation
|
||||
*
|
||||
* @param field field name. must not be {@code null}.
|
||||
* @param lowerPoint lower portion of the range (inclusive).
|
||||
* @param upperPoint upper portion of the range (inclusive).
|
||||
* @param numDims number of dimensions.
|
||||
*/
|
||||
public record PRQConfig(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
|
||||
/**
|
||||
* @throws IllegalArgumentException if {@code field} is null, or if {@code lowerValue.length !=
|
||||
* upperValue.length}
|
||||
*/
|
||||
public PRQConfig {
|
||||
checkArgs(field, lowerPoint, upperPoint);
|
||||
if (numDims <= 0) {
|
||||
throw new IllegalArgumentException("numDims must be positive, got " + numDims);
|
||||
}
|
||||
if (lowerPoint.length == 0) {
|
||||
throw new IllegalArgumentException("lowerPoint has length of zero");
|
||||
}
|
||||
if (lowerPoint.length % numDims != 0) {
|
||||
throw new IllegalArgumentException("lowerPoint is not a fixed multiple of numDims");
|
||||
}
|
||||
if (lowerPoint.length != upperPoint.length) {
|
||||
throw new IllegalArgumentException(
|
||||
"lowerPoint has length="
|
||||
+ lowerPoint.length
|
||||
+ " but upperPoint has different length="
|
||||
+ upperPoint.length);
|
||||
}
|
||||
}
|
||||
|
||||
public int bytesPerDim() {
|
||||
return lowerPoint.length / numDims;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int hash = field.hashCode();
|
||||
hash = 31 * hash + Arrays.hashCode(lowerPoint);
|
||||
hash = 31 * hash + Arrays.hashCode(upperPoint);
|
||||
hash = 31 * hash + numDims;
|
||||
hash = 31 * hash + Objects.hashCode(bytesPerDim());
|
||||
return hash;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return o != null && getClass() == o.getClass() && equalsTo(getClass().cast(o));
|
||||
}
|
||||
|
||||
private boolean equalsTo(PRQConfig other) {
|
||||
return Objects.equals(field, other.field)
|
||||
&& numDims == other.numDims
|
||||
&& bytesPerDim() == other.bytesPerDim()
|
||||
&& Arrays.equals(lowerPoint, other.lowerPoint)
|
||||
&& Arrays.equals(upperPoint, other.upperPoint);
|
||||
}
|
||||
}
|
||||
|
||||
public static class PRQWeightImpl extends ConstantScoreWeight {
|
||||
private final PRQConfig config;
|
||||
private final ScoreMode scoreMode;
|
||||
|
||||
private final ByteArrayComparator comparator;
|
||||
|
||||
public PRQWeightImpl(PRQConfig config, ScoreMode scoreMode, Query query, float boost) {
|
||||
super(query, boost);
|
||||
this.config = config;
|
||||
this.scoreMode = scoreMode;
|
||||
this.comparator = ArrayUtil.getUnsignedComparator(this.config.bytesPerDim());
|
||||
}
|
||||
|
||||
private boolean matches(byte[] packedValue) {
|
||||
int offset = 0;
|
||||
for (int dim = 0; dim < config.numDims(); dim++, offset += config.bytesPerDim()) {
|
||||
if (comparator.compare(packedValue, offset, config.lowerPoint(), offset) < 0) {
|
||||
// Doc's value is too low, in this dimension
|
||||
return false;
|
||||
}
|
||||
if (comparator.compare(packedValue, offset, config.upperPoint(), offset) > 0) {
|
||||
// Doc's value is too high, in this dimension
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private Relation relate(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
|
||||
boolean crosses = false;
|
||||
int offset = 0;
|
||||
|
||||
for (int dim = 0; dim < config.numDims(); dim++, offset += config.bytesPerDim()) {
|
||||
|
||||
if (comparator.compare(minPackedValue, offset, config.upperPoint(), offset) > 0
|
||||
|| comparator.compare(maxPackedValue, offset, config.lowerPoint(), offset) < 0) {
|
||||
return Relation.CELL_OUTSIDE_QUERY;
|
||||
}
|
||||
|
||||
crosses |=
|
||||
comparator.compare(minPackedValue, offset, config.lowerPoint(), offset) < 0
|
||||
|| comparator.compare(maxPackedValue, offset, config.upperPoint(), offset) > 0;
|
||||
}
|
||||
|
||||
if (crosses) {
|
||||
return Relation.CELL_CROSSES_QUERY;
|
||||
} else {
|
||||
return Relation.CELL_INSIDE_QUERY;
|
||||
}
|
||||
}
|
||||
|
||||
private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) {
|
||||
return new IntersectVisitor() {
|
||||
|
||||
DocIdSetBuilder.BulkAdder adder;
|
||||
|
||||
@Override
|
||||
public void grow(int count) {
|
||||
adder = result.grow(count);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
adder.add(docID);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator) throws IOException {
|
||||
adder.add(iterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(IntsRef ref) {
|
||||
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
|
||||
adder.add(ref.ints[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (matches(packedValue)) {
|
||||
visit(docID);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
|
||||
if (matches(packedValue)) {
|
||||
adder.add(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
return relate(minPackedValue, maxPackedValue);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/** Create a visitor that clears documents that do NOT match the range. */
|
||||
private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) {
|
||||
return new IntersectVisitor() {
|
||||
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
result.clear(docID);
|
||||
cost[0]--;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator) throws IOException {
|
||||
result.andNot(iterator);
|
||||
cost[0] = Math.max(0, cost[0] - iterator.cost());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(IntsRef ref) {
|
||||
for (int i = ref.offset; i < ref.offset + ref.length; i++) {
|
||||
result.clear(ref.ints[i]);
|
||||
}
|
||||
cost[0] -= ref.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (matches(packedValue) == false) {
|
||||
visit(docID);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
|
||||
if (matches(packedValue) == false) {
|
||||
visit(iterator);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
Relation relation = relate(minPackedValue, maxPackedValue);
|
||||
switch (relation) {
|
||||
case CELL_INSIDE_QUERY:
|
||||
// all points match, skip this subtree
|
||||
return Relation.CELL_OUTSIDE_QUERY;
|
||||
case CELL_OUTSIDE_QUERY:
|
||||
// none of the points match, clear all documents
|
||||
return Relation.CELL_INSIDE_QUERY;
|
||||
case CELL_CROSSES_QUERY:
|
||||
default:
|
||||
return relation;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private boolean checkValidPointValues(PointValues values) throws IOException {
|
||||
if (values == null) {
|
||||
// No docs in this segment/field indexed any points
|
||||
return false;
|
||||
}
|
||||
|
||||
if (values.getNumIndexDimensions() != config.numDims()) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ config.field()
|
||||
+ "\" was indexed with numIndexDimensions="
|
||||
+ values.getNumIndexDimensions()
|
||||
+ " but this query has numDims="
|
||||
+ config.numDims());
|
||||
}
|
||||
if (config.bytesPerDim() != values.getBytesPerDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ config.field()
|
||||
+ "\" was indexed with bytesPerDim="
|
||||
+ values.getBytesPerDimension()
|
||||
+ " but this query has bytesPerDim="
|
||||
+ config.bytesPerDim());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
|
||||
LeafReader reader = context.reader();
|
||||
|
||||
PointValues values = reader.getPointValues(config.field());
|
||||
if (checkValidPointValues(values) == false) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (values.getDocCount() == 0) {
|
||||
return null;
|
||||
} else {
|
||||
final byte[] fieldPackedLower = values.getMinPackedValue();
|
||||
final byte[] fieldPackedUpper = values.getMaxPackedValue();
|
||||
for (int i = 0; i < config.numDims(); ++i) {
|
||||
int offset = i * config.bytesPerDim();
|
||||
if (comparator.compare(config.lowerPoint(), offset, fieldPackedUpper, offset) > 0
|
||||
|| comparator.compare(config.upperPoint(), offset, fieldPackedLower, offset) < 0) {
|
||||
// If this query is a required clause of a boolean query, then returning null here
|
||||
// will help make sure that we don't call ScorerSupplier#get on other required clauses
|
||||
// of the same boolean query, which is an expensive operation for some queries (e.g.
|
||||
// multi-term queries).
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
boolean allDocsMatch;
|
||||
if (values.getDocCount() == reader.maxDoc()) {
|
||||
final byte[] fieldPackedLower = values.getMinPackedValue();
|
||||
final byte[] fieldPackedUpper = values.getMaxPackedValue();
|
||||
allDocsMatch = true;
|
||||
for (int i = 0; i < config.numDims(); ++i) {
|
||||
int offset = i * config.bytesPerDim();
|
||||
if (comparator.compare(config.lowerPoint(), offset, fieldPackedLower, offset) > 0
|
||||
|| comparator.compare(config.upperPoint(), offset, fieldPackedUpper, offset) < 0) {
|
||||
allDocsMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allDocsMatch = false;
|
||||
}
|
||||
|
||||
if (allDocsMatch) {
|
||||
// all docs have a value and all points are within bounds, so everything matches
|
||||
return new ScorerSupplier() {
|
||||
@Override
|
||||
public Scorer get(long leadCost) {
|
||||
return new ConstantScoreScorer(
|
||||
score(), scoreMode, DocIdSetIterator.all(reader.maxDoc()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return reader.maxDoc();
|
||||
}
|
||||
};
|
||||
} else {
|
||||
return new ScorerSupplier() {
|
||||
|
||||
final DocIdSetBuilder result =
|
||||
new DocIdSetBuilder(reader.maxDoc(), values, config.field());
|
||||
final IntersectVisitor visitor = getIntersectVisitor(result);
|
||||
long cost = -1;
|
||||
|
||||
@Override
|
||||
public Scorer get(long leadCost) throws IOException {
|
||||
if (values.getDocCount() == reader.maxDoc()
|
||||
&& values.getDocCount() == values.size()
|
||||
&& cost() > reader.maxDoc() / 2) {
|
||||
// If all docs have exactly one value and the cost is greater
|
||||
// than half the leaf size then maybe we can make things faster
|
||||
// by computing the set of documents that do NOT match the range
|
||||
final FixedBitSet result = new FixedBitSet(reader.maxDoc());
|
||||
result.set(0, reader.maxDoc());
|
||||
long[] cost = new long[] {reader.maxDoc()};
|
||||
values.intersect(getInverseIntersectVisitor(result, cost));
|
||||
final DocIdSetIterator iterator = new BitSetIterator(result, cost[0]);
|
||||
return new ConstantScoreScorer(score(), scoreMode, iterator);
|
||||
}
|
||||
|
||||
values.intersect(visitor);
|
||||
DocIdSetIterator iterator = result.build().iterator();
|
||||
return new ConstantScoreScorer(score(), scoreMode, iterator);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
if (cost == -1) {
|
||||
// Computing the cost may be expensive, so only do it if necessary
|
||||
cost = values.estimateDocCount(visitor);
|
||||
assert cost >= 0;
|
||||
}
|
||||
return cost;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int count(LeafReaderContext context) throws IOException {
|
||||
LeafReader reader = context.reader();
|
||||
|
||||
PointValues values = reader.getPointValues(config.field());
|
||||
if (checkValidPointValues(values) == false) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (reader.hasDeletions() == false) {
|
||||
if (relate(values.getMinPackedValue(), values.getMaxPackedValue())
|
||||
== Relation.CELL_INSIDE_QUERY) {
|
||||
return values.getDocCount();
|
||||
}
|
||||
// only 1D: we have the guarantee that it will actually run fast since there are at most 2
|
||||
// crossing leaves.
|
||||
// docCount == size : counting according number of points in leaf node, so must be
|
||||
// single-valued.
|
||||
if (config.numDims() == 1 && values.getDocCount() == values.size()) {
|
||||
return (int) pointCount(values.getPointTree(), this::relate, this::matches);
|
||||
}
|
||||
}
|
||||
return super.count(context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the number of points matching the provided range conditions. Using this method is
|
||||
* faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of
|
||||
* intersecting points. This method does not enforce live documents, therefore it should only be
|
||||
* used when there are no deleted documents.
|
||||
*
|
||||
* @param pointTree start node of the count operation
|
||||
* @param nodeComparator comparator to be used for checking whether the internal node is inside
|
||||
* the range
|
||||
* @param leafComparator comparator to be used for checking whether the leaf node is inside the
|
||||
* range
|
||||
* @return count of points that match the range
|
||||
*/
|
||||
private long pointCount(
|
||||
PointValues.PointTree pointTree,
|
||||
BiFunction<byte[], byte[], Relation> nodeComparator,
|
||||
Predicate<byte[]> leafComparator)
|
||||
throws IOException {
|
||||
final long[] matchingNodeCount = {0};
|
||||
// create a custom IntersectVisitor that records the number of leafNodes that matched
|
||||
final IntersectVisitor visitor =
|
||||
new IntersectVisitor() {
|
||||
@Override
|
||||
public void visit(int docID) {
|
||||
// this branch should be unreachable
|
||||
throw new UnsupportedOperationException(
|
||||
"This IntersectVisitor does not perform any actions on a "
|
||||
+ "docID="
|
||||
+ docID
|
||||
+ " node being visited");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(int docID, byte[] packedValue) {
|
||||
if (leafComparator.test(packedValue)) {
|
||||
matchingNodeCount[0]++;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
|
||||
return nodeComparator.apply(minPackedValue, maxPackedValue);
|
||||
}
|
||||
};
|
||||
pointCount(visitor, pointTree, matchingNodeCount);
|
||||
return matchingNodeCount[0];
|
||||
}
|
||||
|
||||
private void pointCount(
|
||||
IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount)
|
||||
throws IOException {
|
||||
Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
|
||||
switch (r) {
|
||||
case CELL_OUTSIDE_QUERY:
|
||||
// This cell is fully outside the query shape: return 0 as the count of its nodes
|
||||
return;
|
||||
case CELL_INSIDE_QUERY:
|
||||
// This cell is fully inside the query shape: return the size of the entire node as the
|
||||
// count
|
||||
matchingNodeCount[0] += pointTree.size();
|
||||
return;
|
||||
case CELL_CROSSES_QUERY:
|
||||
/*
|
||||
The cell crosses the shape boundary, or the cell fully contains the query, so we fall
|
||||
through and do full counting.
|
||||
*/
|
||||
if (pointTree.moveToChild()) {
|
||||
do {
|
||||
pointCount(visitor, pointTree, matchingNodeCount);
|
||||
} while (pointTree.moveToSibling());
|
||||
pointTree.moveToParent();
|
||||
} else {
|
||||
// we have reached a leaf node here.
|
||||
pointTree.visitDocValues(visitor);
|
||||
// leaf node count is saved in the matchingNodeCount array by the visitor
|
||||
}
|
||||
return;
|
||||
default:
|
||||
throw new IllegalArgumentException("Unreachable code");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue