From c96ec0be678f82fce953f39271ef0b767382d984 Mon Sep 17 00:00:00 2001 From: Viswanath Kuchibhotla Date: Tue, 19 Nov 2024 01:14:26 +0530 Subject: [PATCH] Adding filter to the toString() method of KnnFloatVectorQuery (#13990) * Adding filter to toString() of KnnFloatVectorQuery when it's present (addresses https://github.com/apache/lucene/issues/13983) * addressing review comments * adding knnbytevectorquery * unit test improvements * tidy * adding changes entry for the bug fix --- lucene/CHANGES.txt | 2 ++ .../lucene/search/AbstractKnnVectorQuery.java | 2 +- .../apache/lucene/search/KnnByteVectorQuery.java | 9 ++++++++- .../lucene/search/KnnFloatVectorQuery.java | 9 ++++++++- .../lucene/search/TestKnnByteVectorQuery.java | 6 ++++++ .../lucene/search/TestKnnFloatVectorQuery.java | 6 ++++++ .../DiversifyingChildrenByteKnnVectorQuery.java | 9 ++++++++- .../DiversifyingChildrenFloatKnnVectorQuery.java | 9 ++++++++- .../TestParentBlockJoinByteKnnVectorQuery.java | 16 ++++++++++++++++ .../TestParentBlockJoinFloatKnnVectorQuery.java | 16 ++++++++++++++++ 10 files changed, 79 insertions(+), 5 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d360ab524ca..0e697a0dab3 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -80,6 +80,8 @@ Bug Fixes * GITHUB#13944: Ensure deterministic order of clauses for `DisjunctionMaxQuery#toString`. (Laurent Jakubina) * GITHUB#13841: Improve Tessellatorlogic when two holes share the same vertex with the polygon which was failing in valid polygons. (Ignacio Vera) +* GITHUB#13990: Added filter to the toString() method of Knn[Float|Byte]VectorQuery + and DiversifyingChildren[Float|Byte]KnnVectorQuery. (Viswanath Kuchibhotla) Build --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index df19de6cc8d..e9246a8b575 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -55,7 +55,7 @@ abstract class AbstractKnnVectorQuery extends Query { protected final String field; protected final int k; - private final Query filter; + protected final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { this.field = Objects.requireNonNull(field, "field"); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 9d6d71bc7a7..35144055830 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -111,7 +111,14 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { @Override public String toString(String field) { - return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]"; + StringBuilder buffer = new StringBuilder(); + buffer.append(getClass().getSimpleName() + ":"); + buffer.append(this.field + "[" + target[0] + ",...]"); + buffer.append("[" + k + "]"); + if (this.filter != null) { + buffer.append("[" + this.filter + "]"); + } + return buffer.toString(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 585893fa3c2..d2aaf4296ed 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -112,7 +112,14 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { @Override public String toString(String field) { - return getClass().getSimpleName() + ":" + this.field + "[" + target[0] + ",...][" + k + "]"; + StringBuilder buffer = new StringBuilder(); + buffer.append(getClass().getSimpleName() + ":"); + buffer.append(this.field + "[" + target[0] + ",...]"); + buffer.append("[" + k + "]"); + if (this.filter != null) { + buffer.append("[" + this.filter + "]"); + } + return buffer.toString(); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 4dc3d385b08..b45d6e8fb64 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.util.TestVectorUtil; @@ -78,6 +79,11 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { assertEquals("KnnByteVectorQuery:field[0,...][10]", query.toString("ignored")); assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0, 1}, 10, filter); + assertEquals("KnnByteVectorQuery:field[0,...][10][id:text]", query.toString("ignored")); } } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index cd66d6a27cd..feebe858c09 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; @@ -77,6 +78,11 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { assertEquals("KnnFloatVectorQuery:field[0.0,...][10]", query.toString("ignored")); assertDocScoreQueryToString(query.rewrite(newSearcher(reader))); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter); + assertEquals("KnnFloatVectorQuery:field[0.0,...][10][id:text]", query.toString("ignored")); } } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 456a885b49a..45cb8b9c88f 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -154,7 +154,14 @@ public class DiversifyingChildrenByteKnnVectorQuery extends KnnByteVectorQuery { @Override public String toString(String field) { - return getClass().getSimpleName() + ":" + this.field + "[" + query[0] + ",...][" + k + "]"; + StringBuilder buffer = new StringBuilder(); + buffer.append(getClass().getSimpleName() + ":"); + buffer.append(this.field + "[" + query[0] + ",...]"); + buffer.append("[" + k + "]"); + if (this.filter != null) { + buffer.append("[" + this.filter + "]"); + } + return buffer.toString(); } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 7b5a656d141..9c44a2f7856 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -153,7 +153,14 @@ public class DiversifyingChildrenFloatKnnVectorQuery extends KnnFloatVectorQuery @Override public String toString(String field) { - return getClass().getSimpleName() + ":" + this.field + "[" + query[0] + ",...][" + k + "]"; + StringBuilder buffer = new StringBuilder(); + buffer.append(getClass().getSimpleName() + ":"); + buffer.append(this.field + "[" + query[0] + ",...]"); + buffer.append("[" + k + "]"); + if (this.filter != null) { + buffer.append("[" + this.filter + "]"); + } + return buffer.toString(); } @Override diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java index 6f773300a6d..6c1d461d4bf 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinByteKnnVectorQuery.java @@ -29,9 +29,11 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { @@ -81,6 +83,20 @@ public class TestParentBlockJoinByteKnnVectorQuery extends ParentBlockJoinKnnVec } } + public void testToString() { + // test without filter + Query query = getParentJoinKnnQuery("field", new float[] {0, 1}, null, 10, null); + assertEquals( + "DiversifyingChildrenByteKnnVectorQuery:field[0,...][10]", query.toString("ignored")); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getParentJoinKnnQuery("field", new float[] {0, 1}, filter, 10, null); + assertEquals( + "DiversifyingChildrenByteKnnVectorQuery:field[0,...][10][id:text]", + query.toString("ignored")); + } + private static byte[] fromFloat(float[] queryVector) { byte[] query = new byte[queryVector.length]; for (int i = 0; i < queryVector.length; i++) { diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java index 616c8fdb370..f15de3b57ee 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestParentBlockJoinFloatKnnVectorQuery.java @@ -29,9 +29,11 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; import org.apache.lucene.store.Directory; public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVectorQueryTestCase { @@ -110,6 +112,20 @@ public class TestParentBlockJoinFloatKnnVectorQuery extends ParentBlockJoinKnnVe } } + public void testToString() { + // test without filter + Query query = getParentJoinKnnQuery("field", new float[] {0, 1}, null, 10, null); + assertEquals( + "DiversifyingChildrenFloatKnnVectorQuery:field[0.0,...][10]", query.toString("ignored")); + + // test with filter + Query filter = new TermQuery(new Term("id", "text")); + query = getParentJoinKnnQuery("field", new float[] {0.0f, 1.0f}, filter, 10, null); + assertEquals( + "DiversifyingChildrenFloatKnnVectorQuery:field[0.0,...][10][id:text]", + query.toString("ignored")); + } + @Override Field getKnnVectorField(String name, float[] vector) { return new KnnFloatVectorField(name, vector);