diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index bb3e19d24b2..b89446d0536 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -15,6 +15,11 @@ Changes in Runtime Behavior
 
 ======================= Lucene 7.1.0 =======================
 
+New Features
+
+* LUCENE-7621: Add CoveringQuery, a query whose required number of matching
+  clauses can be defined per document. (Adrien Grand)
+
 Optimizations
 
 * LUCENE-7905: Optimize how OrdinalMap (used by
diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java
new file mode 100644
index 00000000000..288e05b05bf
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringQuery.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+
+/** A {@link Query} that allows a configurable number of required
+ *  matches per document. This is typically useful in order to build queries
+ *  whose query terms must all appear in documents.
+ *  @lucene.experimental
+ */
+public final class CoveringQuery extends Query {
+
+  private final Collection<Query> queries;
+  private final LongValuesSource minimumNumberMatch;
+  private final int hashCode;
+
+  /**
+   * Sole constructor.
+   * @param queries Sub queries to match.
+   * @param minimumNumberMatch Per-document long value that records how many queries
+   *                           should match. Values that are less than 1 are treated
+   *                           like 1: only documents that have at least one
+   *                           matching clause will be considered matches. Documents
+   *                           that do not have a value for minimumNumberMatch
+   *                           do not match.
+   */
+  public CoveringQuery(Collection<Query> queries, LongValuesSource minimumNumberMatch) {
+    if (queries.size() > BooleanQuery.getMaxClauseCount()) {
+      throw new BooleanQuery.TooManyClauses();
+    }
+    if (minimumNumberMatch.needsScores()) {
+      throw new IllegalArgumentException("The minimum number of matches may not depend on the score.");
+    }
+    this.queries = new Multiset<>();
+    this.queries.addAll(queries);
+    this.minimumNumberMatch = Objects.requireNonNull(minimumNumberMatch);
+    this.hashCode = computeHashCode();
+  }
+
+  @Override
+  public String toString(String field) {
+    String queriesToString = queries.stream()
+        .map(q -> q.toString(field))
+        .sorted()
+        .collect(Collectors.joining(", "));
+    return "CoveringQuery(queries=[" + queriesToString + "], minimumNumberMatch=" + minimumNumberMatch + ")";
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (sameClassAs(obj) == false) {
+      return false;
+    }
+    CoveringQuery that = (CoveringQuery) obj;
+    return hashCode == that.hashCode // not necessary but makes equals faster
+        && Objects.equals(queries, that.queries)
+        && Objects.equals(minimumNumberMatch, that.minimumNumberMatch);
+  }
+
+  private int computeHashCode() {
+    int h = classHash();
+    h = 31 * h + queries.hashCode();
+    h = 31 * h + minimumNumberMatch.hashCode();
+    return h;
+  }
+
+  @Override
+  public int hashCode() {
+    return hashCode;
+  }
+
+  @Override
+  public Query rewrite(IndexReader reader) throws IOException {
+    Multiset<Query> rewritten = new Multiset<>();
+    boolean actuallyRewritten = false;
+    for (Query query : queries) {
+      Query r = query.rewrite(reader);
+      rewritten.add(r);
+      actuallyRewritten |= query != r;
+    }
+    if (actuallyRewritten) {
+      return new CoveringQuery(rewritten, minimumNumberMatch);
+    }
+    return super.rewrite(reader);
+  }
+
+  @Override
+  public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException {
+    final List<Weight> weights = new ArrayList<>(queries.size());
+    for (Query query : queries) {
+      weights.add(searcher.createWeight(query, needsScores, boost));
+    }
+    return new CoveringWeight(this, weights, minimumNumberMatch);
+  }
+
+  private static class CoveringWeight extends Weight {
+
+    private final Collection<Weight> weights;
+    private final LongValuesSource minimumNumberMatch;
+
+    CoveringWeight(Query query, Collection<Weight> weights, LongValuesSource minimumNumberMatch) {
+      super(query);
+      this.weights = weights;
+      this.minimumNumberMatch = minimumNumberMatch;
+    }
+
+    @Override
+    public void extractTerms(Set<Term> terms) {
+      for (Weight weight : weights) {
+        weight.extractTerms(terms);
+      }
+    }
+
+    @Override
+    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+      LongValues minMatchValues = minimumNumberMatch.getValues(context, null);
+      if (minMatchValues.advanceExact(doc) == false) {
+        return Explanation.noMatch("minimumNumberMatch has no value on the current document");
+      }
+      final long minimumNumberMatch = Math.max(1, minMatchValues.longValue());
+      int freq = 0;
+      double score = 0;
+      List<Explanation> subExpls = new ArrayList<>();
+      for (Weight weight : weights) {
+        Explanation subExpl = weight.explain(context, doc);
+        if (subExpl.isMatch()) {
+          freq++;
+          score += subExpl.getValue();
+        }
+        subExpls.add(subExpl);
+      }
+      if (freq >= minimumNumberMatch) {
+        return Explanation.match((float) score, freq + " matches for " + minimumNumberMatch + " required matches, sum of:", subExpls);
+      } else {
+        return Explanation.noMatch(freq + " matches for " + minimumNumberMatch + " required matches", subExpls);
+      }
+    }
+
+    @Override
+    public Scorer scorer(LeafReaderContext context) throws IOException {
+      Collection<Scorer> scorers = new ArrayList<>();
+      for (Weight w : weights) {
+        Scorer s = w.scorer(context);
+        if (s != null) {
+          scorers.add(s);
+        }
+      }
+      if (scorers.isEmpty()) {
+        return null;
+      }
+      return new CoveringScorer(this, scorers, minimumNumberMatch.getValues(context, null), context.reader().maxDoc());
+    }
+  }
+
+}
diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/CoveringScorer.java b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringScorer.java
new file mode 100644
index 00000000000..8f62d236d46
--- /dev/null
+++ b/lucene/sandbox/src/java/org/apache/lucene/search/CoveringScorer.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/** A {@link Scorer} whose number of matches is per-document. */
+final class CoveringScorer extends Scorer {
+
+  final int numScorers;
+  final int maxDoc;
+  final LongValues minMatchValues;
+
+  boolean matches; // if true then the doc matches, otherwise we don't know and need to check
+  int doc; // current doc ID
+  DisiWrapper topList; // list of matches
+  int freq; // number of scorers on the desired doc ID
+  long minMatch; // current required number of matches
+
+  // priority queue that stores all scorers
+  final DisiPriorityQueue subScorers;
+
+  final long cost;
+
+  CoveringScorer(Weight weight, Collection<Scorer> scorers, LongValues minMatchValues, int maxDoc) {
+    super(weight);
+
+    this.numScorers = scorers.size();
+    this.maxDoc = maxDoc;
+    this.minMatchValues = minMatchValues;
+    this.doc = -1;
+
+    subScorers = new DisiPriorityQueue(scorers.size());
+
+    for (Scorer scorer : scorers) {
+      subScorers.add(new DisiWrapper(scorer));
+    }
+
+    this.cost = scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost).sum();
+  }
+
+  @Override
+  public final Collection<ChildScorer> getChildren() throws IOException {
+    List<ChildScorer> matchingChildren = new ArrayList<>();
+    setTopListAndFreqIfNecessary();
+    for (DisiWrapper s = topList; s != null; s = s.next) {
+      matchingChildren.add(new ChildScorer(s.scorer, "SHOULD"));
+    }
+    return matchingChildren;
+  }
+
+  private final DocIdSetIterator approximation = new DocIdSetIterator() {
+
+    @Override
+    public int docID() {
+      return doc;
+    }
+
+    @Override
+    public int nextDoc() throws IOException {
+      return advance(docID() + 1);
+    }
+
+    @Override
+    public int advance(int target) throws IOException {
+      // reset state
+      matches = false;
+      topList = null;
+
+      doc = target;
+      setMinMatch();
+
+      DisiWrapper top = subScorers.top();
+      int numMatches = 0;
+      int maxPotentialMatches = numScorers;
+      while (top.doc < target) {
+        if (maxPotentialMatches < minMatch) {
+          // No need to keep trying to advance to `target` since no match is possible.
+          if (target >= maxDoc - 1) {
+            doc = NO_MORE_DOCS;
+          } else {
+            doc = target + 1;
+          }
+          setMinMatch();
+          return doc;
+        }
+        top.doc = top.iterator.advance(target);
+        boolean match = top.doc == target;
+        top = subScorers.updateTop();
+        if (match) {
+          numMatches++;
+          if (numMatches >= minMatch) {
+            // success, no need to check other iterators
+            matches = true;
+            return doc;
+          }
+        } else {
+          maxPotentialMatches--;
+        }
+      }
+
+      doc = top.doc;
+      setMinMatch();
+      return doc;
+    }
+
+    private void setMinMatch() throws IOException {
+      if (doc >= maxDoc) {
+        // advanceExact may not be called on out-of-range doc ids
+        minMatch = 1;
+      } else if (minMatchValues.advanceExact(doc)) {
+        // values < 1 are treated as 1: we require at least one match
+        minMatch = Math.max(1, minMatchValues.longValue());
+      } else {
+        // this will make sure the document does not match
+        minMatch = Long.MAX_VALUE;
+      }
+    }
+
+    @Override
+    public long cost() {
+      return maxDoc;
+    }
+
+  };
+
+  private final TwoPhaseIterator twoPhase = new TwoPhaseIterator(approximation) {
+
+    @Override
+    public boolean matches() throws IOException {
+      if (matches) {
+        return true;
+      }
+      if (topList == null) {
+        advanceAll(doc);
+      }
+      if (subScorers.top().doc != doc) {
+        assert subScorers.top().doc > doc;
+        return false;
+      }
+      setTopListAndFreq();
+      assert topList.doc == doc;
+      return matches = freq >= minMatch;
+    }
+
+    @Override
+    public float matchCost() {
+      return numScorers;
+    }
+
+  };
+
+  @Override
+  public DocIdSetIterator iterator() {
+    return TwoPhaseIterator.asDocIdSetIterator(twoPhase);
+  }
+
+  @Override
+  public TwoPhaseIterator twoPhaseIterator() {
+    return twoPhase;
+  }
+
+  private void advanceAll(int target) throws IOException {
+    DisiWrapper top = subScorers.top();
+    while (top.doc < target) {
+      top.doc = top.iterator.advance(target);
+      top = subScorers.updateTop();
+    }
+  }
+
+  private void setTopListAndFreq() {
+    topList = subScorers.topList();
+    freq = 0;
+    for (DisiWrapper w = topList; w != null; w = w.next) {
+      freq++;
+    }
+  }
+
+  private void setTopListAndFreqIfNecessary() throws IOException {
+    if (topList == null) {
+      advanceAll(doc);
+      setTopListAndFreq();
+    }
+  }
+
+  @Override
+  public int freq() throws IOException {
+    setTopListAndFreqIfNecessary();
+    return freq;
+  }
+
+  @Override
+  public float score() throws IOException {
+    // we need to know about all matches
+    setTopListAndFreqIfNecessary();
+    double score = 0;
+    for (DisiWrapper w = topList; w != null; w = w.next) {
+      score += w.scorer.score();
+    }
+    return (float) score;
+  }
+
+  @Override
+  public int docID() {
+    return doc;
+  }
+
+}
diff --git a/lucene/sandbox/src/test/org/apache/lucene/search/TestCoveringQuery.java b/lucene/sandbox/src/test/org/apache/lucene/search/TestCoveringQuery.java
new file mode 100644
index 00000000000..29422896bfa
--- /dev/null
+++ b/lucene/sandbox/src/test/org/apache/lucene/search/TestCoveringQuery.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field.Store;
+import org.apache.lucene.document.NumericDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.MultiReader;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause.Occur;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+
+public class TestCoveringQuery extends LuceneTestCase {
+
+  public void testEquals() {
+    TermQuery tq1 = new TermQuery(new Term("foo", "bar"));
+    TermQuery tq2 = new TermQuery(new Term("foo", "quux"));
+    LongValuesSource vs = LongValuesSource.fromLongField("field");
+
+    CoveringQuery q1 = new CoveringQuery(Arrays.asList(tq1, tq2), vs);
+    CoveringQuery q2 = new CoveringQuery(Arrays.asList(tq1, tq2), vs);
+    QueryUtils.checkEqual(q1, q2);
+
+    // order does not matter
+    CoveringQuery q3 = new CoveringQuery(Arrays.asList(tq2, tq1), vs);
+    QueryUtils.checkEqual(q1, q3);
+
+    // values source matters
+    CoveringQuery q4 = new CoveringQuery(Arrays.asList(tq2, tq1), LongValuesSource.fromLongField("other_field"));
+    QueryUtils.checkUnequal(q1, q4);
+
+    // duplicates matter
+    CoveringQuery q5 = new CoveringQuery(Arrays.asList(tq1, tq1, tq2), vs);
+    CoveringQuery q6 = new CoveringQuery(Arrays.asList(tq1, tq2, tq2), vs);
+    QueryUtils.checkUnequal(q5, q6);
+
+    // query matters
+    CoveringQuery q7 = new CoveringQuery(Arrays.asList(tq1), vs);
+    CoveringQuery q8 = new CoveringQuery(Arrays.asList(tq2), vs);
+    QueryUtils.checkUnequal(q7, q8);
+  }
+
+  public void testRewrite() throws IOException {
+    PhraseQuery pq = new PhraseQuery("foo", "bar");
+    TermQuery tq = new TermQuery(new Term("foo", "bar"));
+    LongValuesSource vs = LongValuesSource.fromIntField("field");
+    assertEquals(
+        new CoveringQuery(Collections.singleton(tq), vs),
+        new CoveringQuery(Collections.singleton(pq), vs).rewrite(new MultiReader()));
+  }
+
+  public void testToString() {
+    TermQuery tq1 = new TermQuery(new Term("foo", "bar"));
+    TermQuery tq2 = new TermQuery(new Term("foo", "quux"));
+    LongValuesSource vs = LongValuesSource.fromIntField("field");
+    CoveringQuery q = new CoveringQuery(Arrays.asList(tq1, tq2), vs);
+    assertEquals("CoveringQuery(queries=[foo:bar, foo:quux], minimumNumberMatch=long(field))", q.toString());
+    assertEquals("CoveringQuery(queries=[bar, quux], minimumNumberMatch=long(field))", q.toString("foo"));
+  }
+
+  public void testRandom() throws IOException {
+    Directory dir = newDirectory();
+    IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
+    int numDocs = atLeast(200);
+    for (int i = 0; i < numDocs; ++i) {
+      Document doc = new Document();
+      if (random().nextBoolean()) {
+        doc.add(new StringField("field", "A", Store.NO));
+      }
+      if (random().nextBoolean()) {
+        doc.add(new StringField("field", "B", Store.NO));
+      }
+      if (random().nextDouble() > 0.9) {
+        doc.add(new StringField("field", "C", Store.NO));
+      }
+      if (random().nextDouble() > 0.1) {
+        doc.add(new StringField("field", "D", Store.NO));
+      }
+      doc.add(new NumericDocValuesField("min_match", random().nextInt(6)));
+      w.addDocument(doc);
+    }
+
+    IndexReader r = DirectoryReader.open(w);
+    IndexSearcher searcher = new IndexSearcher(r);
+    w.close();
+
+    int iters = atLeast(10);
+    for (int iter = 0; iter < iters; ++iter) {
+      List<Query> queries = new ArrayList<>();
+      if (random().nextBoolean()) {
+        queries.add(new TermQuery(new Term("field", "A")));
+      }
+      if (random().nextBoolean()) {
+        queries.add(new TermQuery(new Term("field", "B")));
+      }
+      if (random().nextBoolean()) {
+        queries.add(new TermQuery(new Term("field", "C")));
+      }
+      if (random().nextBoolean()) {
+        queries.add(new TermQuery(new Term("field", "D")));
+      }
+      if (random().nextBoolean()) {
+        queries.add(new TermQuery(new Term("field", "E")));
+      }
+
+      Query q = new CoveringQuery(queries, LongValuesSource.fromLongField("min_match"));
+      QueryUtils.check(random(), q, searcher);
+
+      for (int i = 1; i < 4; ++i) {
+        BooleanQuery.Builder builder = new BooleanQuery.Builder()
+            .setMinimumNumberShouldMatch(i);
+        for (Query query : queries) {
+          builder.add(query, Occur.SHOULD);
+        }
+        Query q1 = builder.build();
+        Query q2 = new CoveringQuery(queries, LongValuesSource.constant(i));
+        assertEquals(
+            searcher.count(q1),
+            searcher.count(q2));
+      }
+
+      Query filtered = new BooleanQuery.Builder()
+          .add(q, Occur.MUST)
+          .add(new TermQuery(new Term("field", "A")), Occur.MUST)
+          .build();
+      QueryUtils.check(random(), filtered, searcher);
+    }
+
+    r.close();
+    dir.close();
+  }
+}
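Usage note (not part of the patch): a minimal sketch of how the new query might be used, assuming documents index a numeric doc-values field, here called "min_match" as in TestCoveringQuery above, that stores each document's required number of matching clauses. Field and term names are illustrative only.

    import java.io.IOException;
    import java.util.Arrays;
    import java.util.List;

    import org.apache.lucene.index.Term;
    import org.apache.lucene.search.CoveringQuery;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.LongValuesSource;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TermQuery;
    import org.apache.lucene.search.TopDocs;

    public class CoveringQueryExample {
      public static TopDocs search(IndexSearcher searcher) throws IOException {
        // Candidate clauses; any Query implementation may be used.
        List<Query> clauses = Arrays.<Query>asList(
            new TermQuery(new Term("body", "lucene")),
            new TermQuery(new Term("body", "search")),
            new TermQuery(new Term("body", "query")));
        // Each document must match at least as many clauses as the value of its
        // "min_match" doc-values field; documents without a value never match.
        Query query = new CoveringQuery(clauses, LongValuesSource.fromLongField("min_match"));
        return searcher.search(query, 10);
      }
    }

Compared to BooleanQuery.Builder#setMinimumNumberShouldMatch, which fixes the threshold per query, CoveringQuery reads the threshold per document; testRandom above checks that the two agree when the threshold is a constant.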