diff --git a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
index f823f3a1426..1c775b01bee 100644
--- a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
+++ b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
@@ -60,7 +60,6 @@ import java.util.Objects;
* which is the minimum number of documents the terms occurs in.
*
*/
-// TODO maybe contribute to Lucene
public abstract class BlendedTermQuery extends Query {
private final Term[] terms;
@@ -246,36 +245,82 @@ public abstract class BlendedTermQuery extends Query {
return builder.toString();
}
- private volatile Term[] equalTerms = null;
+ private class TermAndBoost implements Comparable {
+ protected final Term term;
+ protected float boost;
- private Term[] equalsTerms() {
- if (terms.length == 1) {
- return terms;
+ protected TermAndBoost(Term term, float boost) {
+ this.term = term;
+ this.boost = boost;
}
- if (equalTerms == null) {
+
+ @Override
+ public int compareTo(TermAndBoost other) {
+ int compareTo = term.compareTo(other.term);
+ if (compareTo == 0) {
+ compareTo = Float.compare(boost, other.boost);
+ }
+ return compareTo;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o instanceof TermAndBoost == false) {
+ return false;
+ }
+
+ TermAndBoost that = (TermAndBoost) o;
+ return term.equals(that.term) && (Float.compare(boost, that.boost) == 0);
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * term.hashCode() + Float.hashCode(boost);
+ }
+ }
+
+ private volatile TermAndBoost[] equalTermsAndBoosts = null;
+
+ private TermAndBoost[] equalsTermsAndBoosts() {
+ if (equalTermsAndBoosts != null) {
+ return equalTermsAndBoosts;
+ }
+ if (terms.length == 1) {
+ float boost = (boosts != null ? boosts[0] : 1f);
+ equalTermsAndBoosts = new TermAndBoost[] {new TermAndBoost(terms[0], boost)};
+ } else {
// sort the terms to make sure equals and hashCode are consistent
// this should be a very small cost and equivalent to a HashSet but less object creation
- final Term[] t = new Term[terms.length];
- System.arraycopy(terms, 0, t, 0, terms.length);
- ArrayUtil.timSort(t);
- equalTerms = t;
+ equalTermsAndBoosts = new TermAndBoost[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ float boost = (boosts != null ? boosts[i] : 1f);
+ equalTermsAndBoosts[i] = new TermAndBoost(terms[i], boost);
+ }
+ ArrayUtil.timSort(equalTermsAndBoosts);
}
- return equalTerms;
-
+ return equalTermsAndBoosts;
}
@Override
public boolean equals(Object o) {
- if (this == o) return true;
- if (sameClassAs(o) == false) return false;
+ if (this == o) {
+ return true;
+ }
+ if (sameClassAs(o) == false) {
+ return false;
+ }
BlendedTermQuery that = (BlendedTermQuery) o;
- return Arrays.equals(equalsTerms(), that.equalsTerms());
+ return Arrays.equals(equalsTermsAndBoosts(), that.equalsTermsAndBoosts());
+
}
@Override
public int hashCode() {
- return Objects.hash(classHash(), Arrays.hashCode(equalsTerms()));
+ return Objects.hash(classHash(), Arrays.hashCode(equalsTermsAndBoosts()));
}
/**
diff --git a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
index 9d05e119cbb..ba7044b064e 100644
--- a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
+++ b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
@@ -44,6 +44,9 @@ import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.store.Directory;
import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.EqualsHashCodeTestUtils;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.CopyFunction;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.MutateFunction;
import java.io.IOException;
import java.util.Arrays;
@@ -257,4 +260,72 @@ public class BlendedTermQueryTests extends ESTestCase {
w.close();
dir.close();
}
+
+ public void testEqualsAndHash() {
+ String[] fields = new String[1 + random().nextInt(10)];
+ for (int i = 0; i < fields.length; i++) {
+ fields[i] = randomRealisticUnicodeOfLengthBetween(1, 10);
+ }
+ String term = randomRealisticUnicodeOfLengthBetween(1, 10);
+ Term[] terms = toTerms(fields, term);
+ float tieBreaker = randomFloat();
+ final float[] boosts;
+ if (randomBoolean()) {
+ boosts = new float[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ boosts[i] = randomFloat();
+ }
+ } else {
+ boosts = null;
+ }
+
+ BlendedTermQuery original = BlendedTermQuery.dismaxBlendedQuery(terms, boosts, tieBreaker);
+ CopyFunction copyFunction = org -> {
+ Term[] termsCopy = new Term[terms.length];
+ System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+
+ float[] boostsCopy = null;
+ if (boosts != null) {
+ boostsCopy = new float[boosts.length];
+ System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+ }
+ if (randomBoolean() && terms.length > 1) {
+ // if we swap two elements, the resulting query should still be regarded as equal
+ int swapPos = randomIntBetween(1, terms.length - 1);
+
+ Term swpTerm = termsCopy[0];
+ termsCopy[0] = termsCopy[swapPos];
+ termsCopy[swapPos] = swpTerm;
+
+ if (boosts != null) {
+ float swpBoost = boostsCopy[0];
+ boostsCopy[0] = boostsCopy[swapPos];
+ boostsCopy[swapPos] = swpBoost;
+ }
+ }
+ return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boostsCopy, tieBreaker);
+ };
+ MutateFunction mutateFunction = org -> {
+ if (randomBoolean()) {
+ Term[] termsCopy = new Term[terms.length];
+ System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+ termsCopy[randomIntBetween(0, terms.length - 1)] = new Term(randomAlphaOfLength(10), randomAlphaOfLength(10));
+ return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boosts, tieBreaker);
+ } else {
+ float[] boostsCopy = null;
+ if (boosts != null) {
+ boostsCopy = new float[boosts.length];
+ System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+ boostsCopy[randomIntBetween(0, terms.length - 1)] = randomFloat();
+ } else {
+ boostsCopy = new float[terms.length];
+ for (int i = 0; i < terms.length; i++) {
+ boostsCopy[i] = randomFloat();
+ }
+ }
+ return BlendedTermQuery.dismaxBlendedQuery(terms, boostsCopy, tieBreaker);
+ }
+ };
+ EqualsHashCodeTestUtils.checkEqualsAndHashCode(original, copyFunction, mutateFunction );
+ }
}