LUCENE-9940: DisjunctionMaxQuery shouldn't depend on disjunct order for equals checks (#110)

DisjunctionMaxQuery stores its disjuncts in a Query[], and uses
Arrays.equals() for comparisons in its equals() implementation.
This means that the order in which disjuncts are added to the query
matters for equality checks: two queries built from the same disjuncts
in a different order are not considered equal.

This commit changes DisjunctionMaxQuery to instead store its disjuncts
in a Multiset, meaning that ordering no longer matters for equals() and
hashCode(). The getDisjuncts() method now returns a Collection<Query>
rather than a List<Query>, and some tests are changed to use query
equality checks rather than iterating over disjuncts and expecting a
particular order.
This commit is contained in:
Alan Woodward 2021-04-29 09:47:55 +01:00 committed by GitHub
parent 043ed3a91f
commit f7a3587091
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 25 deletions

View File

@ -109,7 +109,6 @@ API Changes
only applicable for fields that are indexed with doc values only. (Mayya Sharipova, only applicable for fields that are indexed with doc values only. (Mayya Sharipova,
Adrien Grand, Simon Willnauer) Adrien Grand, Simon Willnauer)
Improvements Improvements
* LUCENE-9687: Hunspell support improvements: add API for spell-checking and suggestions, support compound words, * LUCENE-9687: Hunspell support improvements: add API for spell-checking and suggestions, support compound words,
@ -246,6 +245,9 @@ Bug fixes
* LUCENE-9930: The Ukrainian analyzer was reloading its dictionary for every new * LUCENE-9930: The Ukrainian analyzer was reloading its dictionary for every new
TokenStreamComponents, which could lead to memory leaks. (Alan Woodward) TokenStreamComponents, which could lead to memory leaks. (Alan Woodward)
* LUCENE-9940: The order of disjuncts in DisjunctionMaxQuery does not matter
for equality checks (Alan Woodward)
Changes in Backwards Compatibility Policy Changes in Backwards Compatibility Policy
* LUCENE-9904: regenerated UAX29URLEmailTokenizer and the corresponding analyzer with up-to-date top * LUCENE-9904: regenerated UAX29URLEmailTokenizer and the corresponding analyzer with up-to-date top

View File

@ -18,7 +18,6 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
@ -45,7 +44,7 @@ import org.apache.lucene.index.LeafReaderContext;
public final class DisjunctionMaxQuery extends Query implements Iterable<Query> { public final class DisjunctionMaxQuery extends Query implements Iterable<Query> {
/* The subqueries */ /* The subqueries */
private final Query[] disjuncts; private final Multiset<Query> disjuncts = new Multiset<>();
/* Multiple of the non-max disjunct scores added into our final score. Non-zero values support tie-breaking. */ /* Multiple of the non-max disjunct scores added into our final score. Non-zero values support tie-breaking. */
private final float tieBreakerMultiplier; private final float tieBreakerMultiplier;
@ -66,7 +65,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]"); throw new IllegalArgumentException("tieBreakerMultiplier must be in [0, 1]");
} }
this.tieBreakerMultiplier = tieBreakerMultiplier; this.tieBreakerMultiplier = tieBreakerMultiplier;
this.disjuncts = disjuncts.toArray(new Query[disjuncts.size()]); this.disjuncts.addAll(disjuncts);
} }
/** @return An {@code Iterator<Query>} over the disjuncts */ /** @return An {@code Iterator<Query>} over the disjuncts */
@ -76,8 +75,8 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
} }
/** @return the disjuncts. */ /** @return the disjuncts. */
public List<Query> getDisjuncts() { public Collection<Query> getDisjuncts() {
return Collections.unmodifiableList(Arrays.asList(disjuncts)); return Collections.unmodifiableCollection(disjuncts);
} }
/** @return tie breaker value for multiple matches. */ /** @return tie breaker value for multiple matches. */
@ -208,8 +207,8 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
*/ */
@Override @Override
public Query rewrite(IndexReader reader) throws IOException { public Query rewrite(IndexReader reader) throws IOException {
if (disjuncts.length == 1) { if (disjuncts.size() == 1) {
return disjuncts[0]; return disjuncts.iterator().next();
} }
if (tieBreakerMultiplier == 1.0f) { if (tieBreakerMultiplier == 1.0f) {
@ -254,14 +253,15 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
public String toString(String field) { public String toString(String field) {
StringBuilder buffer = new StringBuilder(); StringBuilder buffer = new StringBuilder();
buffer.append("("); buffer.append("(");
for (int i = 0; i < disjuncts.length; i++) { Iterator<Query> it = disjuncts.iterator();
Query subquery = disjuncts[i]; for (int i = 0; it.hasNext(); i++) {
Query subquery = it.next();
if (subquery instanceof BooleanQuery) { // wrap sub-bools in parens if (subquery instanceof BooleanQuery) { // wrap sub-bools in parens
buffer.append("("); buffer.append("(");
buffer.append(subquery.toString(field)); buffer.append(subquery.toString(field));
buffer.append(")"); buffer.append(")");
} else buffer.append(subquery.toString(field)); } else buffer.append(subquery.toString(field));
if (i != disjuncts.length - 1) buffer.append(" | "); if (i != disjuncts.size() - 1) buffer.append(" | ");
} }
buffer.append(")"); buffer.append(")");
if (tieBreakerMultiplier != 0.0f) { if (tieBreakerMultiplier != 0.0f) {
@ -285,7 +285,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
private boolean equalsTo(DisjunctionMaxQuery other) { private boolean equalsTo(DisjunctionMaxQuery other) {
return tieBreakerMultiplier == other.tieBreakerMultiplier return tieBreakerMultiplier == other.tieBreakerMultiplier
&& Arrays.equals(disjuncts, other.disjuncts); && Objects.equals(disjuncts, other.disjuncts);
} }
/** /**
@ -297,7 +297,7 @@ public final class DisjunctionMaxQuery extends Query implements Iterable<Query>
public int hashCode() { public int hashCode() {
int h = classHash(); int h = classHash();
h = 31 * h + Float.floatToIntBits(tieBreakerMultiplier); h = 31 * h + Float.floatToIntBits(tieBreakerMultiplier);
h = 31 * h + Arrays.hashCode(disjuncts); h = 31 * h + Objects.hashCode(disjuncts);
return h; return h;
} }
} }

View File

@ -496,11 +496,21 @@ public class TestDisjunctionMaxQuery extends LuceneTestCase {
Query sub2 = tq("hed", "elephant"); Query sub2 = tq("hed", "elephant");
DisjunctionMaxQuery q = new DisjunctionMaxQuery(Arrays.asList(sub1, sub2), 1.0f); DisjunctionMaxQuery q = new DisjunctionMaxQuery(Arrays.asList(sub1, sub2), 1.0f);
Query rewritten = s.rewrite(q); Query rewritten = s.rewrite(q);
assertTrue(rewritten instanceof BooleanQuery); Query expected =
BooleanQuery bq = (BooleanQuery) rewritten; new BooleanQuery.Builder()
assertEquals(bq.clauses().size(), 2); .add(sub1, BooleanClause.Occur.SHOULD)
assertEquals(bq.clauses().get(0), new BooleanClause(sub1, BooleanClause.Occur.SHOULD)); .add(sub2, BooleanClause.Occur.SHOULD)
assertEquals(bq.clauses().get(1), new BooleanClause(sub2, BooleanClause.Occur.SHOULD)); .build();
assertEquals(expected, rewritten);
}
public void testDisjunctOrderAndEquals() throws Exception {
// the order that disjuncts are provided in should not matter for equals() comparisons
Query sub1 = tq("hed", "albino");
Query sub2 = tq("hed", "elephant");
Query q1 = new DisjunctionMaxQuery(Arrays.asList(sub1, sub2), 1.0f);
Query q2 = new DisjunctionMaxQuery(Arrays.asList(sub2, sub1), 1.0f);
assertEquals(q1, q2);
} }
public void testRandomTopDocs() throws Exception { public void testRandomTopDocs() throws Exception {

View File

@ -18,16 +18,19 @@ package org.apache.lucene.queryparser.xml;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Arrays;
import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.MockAnalyzer; import org.apache.lucene.analysis.MockAnalyzer;
import org.apache.lucene.analysis.MockTokenFilter; import org.apache.lucene.analysis.MockTokenFilter;
import org.apache.lucene.analysis.MockTokenizer; import org.apache.lucene.analysis.MockTokenizer;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DisjunctionMaxQuery; import org.apache.lucene.search.DisjunctionMaxQuery;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.spans.SpanQuery; import org.apache.lucene.search.spans.SpanQuery;
import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase;
@ -99,13 +102,14 @@ public class TestCoreParser extends LuceneTestCase {
public void testDisjunctionMaxQueryXML() throws ParserException, IOException { public void testDisjunctionMaxQueryXML() throws ParserException, IOException {
Query q = parse("DisjunctionMaxQuery.xml"); Query q = parse("DisjunctionMaxQuery.xml");
assertTrue(q instanceof DisjunctionMaxQuery); Query expected =
DisjunctionMaxQuery d = (DisjunctionMaxQuery) q; new DisjunctionMaxQuery(
assertEquals(0.0f, d.getTieBreakerMultiplier(), 0.0001f); Arrays.asList(
assertEquals(2, d.getDisjuncts().size()); new TermQuery(new Term("a", "merger")),
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery) d.getDisjuncts().get(1); new DisjunctionMaxQuery(
assertEquals(0.3f, ndq.getTieBreakerMultiplier(), 0.0001f); Arrays.asList(new TermQuery(new Term("b", "verger"))), 0.3f)),
assertEquals(1, ndq.getDisjuncts().size()); 0.0f);
assertEquals(expected, q);
} }
public void testRangeQueryXML() throws ParserException, IOException { public void testRangeQueryXML() throws ParserException, IOException {