BlendedTermQuery's equals method should consider boosts (#48193)

This changes the queries equals() method so that the boost factors for each term
are considered for the equality calculation. This means queries are only equal
if both their terms and associated boosts match. The ordering of the terms
doesn't matter as before, which is why we internally need to sort the terms and
boost for comparison on the first equals() call like before. Boosts that are
`null` are considered equal to boosts of 1.0f because topLevelQuery() will only
wrap into BoostQuery if boost is not null and different from 1f.

Closes #48184
This commit is contained in:
Christoph Büscher 2019-10-25 13:34:44 +02:00
parent 486794f24d
commit 3fb3397c12
2 changed files with 132 additions and 16 deletions

View File

@ -60,7 +60,6 @@ import java.util.Objects;
* which is the minimum number of documents the terms occurs in.
* </p>
*/
// TODO maybe contribute to Lucene
public abstract class BlendedTermQuery extends Query {
private final Term[] terms;
@ -246,36 +245,82 @@ public abstract class BlendedTermQuery extends Query {
return builder.toString();
}
private volatile Term[] equalTerms = null;
private class TermAndBoost implements Comparable<TermAndBoost> {
protected final Term term;
protected float boost;
private Term[] equalsTerms() {
if (terms.length == 1) {
return terms;
protected TermAndBoost(Term term, float boost) {
this.term = term;
this.boost = boost;
}
if (equalTerms == null) {
@Override
public int compareTo(TermAndBoost other) {
int compareTo = term.compareTo(other.term);
if (compareTo == 0) {
compareTo = Float.compare(boost, other.boost);
}
return compareTo;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o instanceof TermAndBoost == false) {
return false;
}
TermAndBoost that = (TermAndBoost) o;
return term.equals(that.term) && (Float.compare(boost, that.boost) == 0);
}
@Override
public int hashCode() {
return 31 * term.hashCode() + Float.hashCode(boost);
}
}
private volatile TermAndBoost[] equalTermsAndBoosts = null;
private TermAndBoost[] equalsTermsAndBoosts() {
if (equalTermsAndBoosts != null) {
return equalTermsAndBoosts;
}
if (terms.length == 1) {
float boost = (boosts != null ? boosts[0] : 1f);
equalTermsAndBoosts = new TermAndBoost[] {new TermAndBoost(terms[0], boost)};
} else {
// sort the terms to make sure equals and hashCode are consistent
// this should be a very small cost and equivalent to a HashSet but less object creation
final Term[] t = new Term[terms.length];
System.arraycopy(terms, 0, t, 0, terms.length);
ArrayUtil.timSort(t);
equalTerms = t;
equalTermsAndBoosts = new TermAndBoost[terms.length];
for (int i = 0; i < terms.length; i++) {
float boost = (boosts != null ? boosts[i] : 1f);
equalTermsAndBoosts[i] = new TermAndBoost(terms[i], boost);
}
ArrayUtil.timSort(equalTermsAndBoosts);
}
return equalTerms;
return equalTermsAndBoosts;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (sameClassAs(o) == false) return false;
if (this == o) {
return true;
}
if (sameClassAs(o) == false) {
return false;
}
BlendedTermQuery that = (BlendedTermQuery) o;
return Arrays.equals(equalsTerms(), that.equalsTerms());
return Arrays.equals(equalsTermsAndBoosts(), that.equalsTermsAndBoosts());
}
@Override
public int hashCode() {
return Objects.hash(classHash(), Arrays.hashCode(equalsTerms()));
return Objects.hash(classHash(), Arrays.hashCode(equalsTermsAndBoosts()));
}
/**

View File

@ -44,6 +44,9 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.EqualsHashCodeTestUtils;
import org.elasticsearch.test.EqualsHashCodeTestUtils.CopyFunction;
import org.elasticsearch.test.EqualsHashCodeTestUtils.MutateFunction;
import java.io.IOException;
import java.util.Arrays;
@ -257,4 +260,72 @@ public class BlendedTermQueryTests extends ESTestCase {
w.close();
dir.close();
}
public void testEqualsAndHash() {
String[] fields = new String[1 + random().nextInt(10)];
for (int i = 0; i < fields.length; i++) {
fields[i] = randomRealisticUnicodeOfLengthBetween(1, 10);
}
String term = randomRealisticUnicodeOfLengthBetween(1, 10);
Term[] terms = toTerms(fields, term);
float tieBreaker = randomFloat();
final float[] boosts;
if (randomBoolean()) {
boosts = new float[terms.length];
for (int i = 0; i < terms.length; i++) {
boosts[i] = randomFloat();
}
} else {
boosts = null;
}
BlendedTermQuery original = BlendedTermQuery.dismaxBlendedQuery(terms, boosts, tieBreaker);
CopyFunction<BlendedTermQuery> copyFunction = org -> {
Term[] termsCopy = new Term[terms.length];
System.arraycopy(terms, 0, termsCopy, 0, terms.length);
float[] boostsCopy = null;
if (boosts != null) {
boostsCopy = new float[boosts.length];
System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
}
if (randomBoolean() && terms.length > 1) {
// if we swap two elements, the resulting query should still be regarded as equal
int swapPos = randomIntBetween(1, terms.length - 1);
Term swpTerm = termsCopy[0];
termsCopy[0] = termsCopy[swapPos];
termsCopy[swapPos] = swpTerm;
if (boosts != null) {
float swpBoost = boostsCopy[0];
boostsCopy[0] = boostsCopy[swapPos];
boostsCopy[swapPos] = swpBoost;
}
}
return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boostsCopy, tieBreaker);
};
MutateFunction<BlendedTermQuery> mutateFunction = org -> {
if (randomBoolean()) {
Term[] termsCopy = new Term[terms.length];
System.arraycopy(terms, 0, termsCopy, 0, terms.length);
termsCopy[randomIntBetween(0, terms.length - 1)] = new Term(randomAlphaOfLength(10), randomAlphaOfLength(10));
return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boosts, tieBreaker);
} else {
float[] boostsCopy = null;
if (boosts != null) {
boostsCopy = new float[boosts.length];
System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
boostsCopy[randomIntBetween(0, terms.length - 1)] = randomFloat();
} else {
boostsCopy = new float[terms.length];
for (int i = 0; i < terms.length; i++) {
boostsCopy[i] = randomFloat();
}
}
return BlendedTermQuery.dismaxBlendedQuery(terms, boostsCopy, tieBreaker);
}
};
EqualsHashCodeTestUtils.checkEqualsAndHashCode(original, copyFunction, mutateFunction );
}
}