From 37c7d156ab54ce9baae08bebb76eebe4da2e5b81 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Mon, 26 Oct 2020 05:00:26 -0700 Subject: [PATCH] LUCENE-9322: Make sure to account for vectors in SortingCodecReader. (#2028) --- .../lucene/index/SortingCodecReader.java | 27 +++++++++++++++++++ .../lucene/index/VectorValuesWriter.java | 2 +- .../lucene/index/TestSortingCodecReader.java | 10 +++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index f5831fd932d..fe2d5e193c4 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -29,6 +29,7 @@ import org.apache.lucene.codecs.NormsProducer; import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.codecs.VectorReader; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.util.Bits; @@ -303,6 +304,32 @@ public final class SortingCodecReader extends FilterCodecReader { }; } + @Override + public VectorReader getVectorReader() { + VectorReader delegate = in.getVectorReader(); + return new VectorReader() { + @Override + public void checkIntegrity() throws IOException { + delegate.checkIntegrity(); + } + + @Override + public VectorValues getVectorValues(String field) throws IOException { + return new VectorValuesWriter.SortingVectorValues(delegate.getVectorValues(field), docMap); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public long ramBytesUsed() { + return delegate.ramBytesUsed(); + } + }; + } + @Override public NormsProducer getNormsReader() { final NormsProducer delegate = in.getNormsReader(); diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java index c4e97018b04..ae39b3a93f7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java @@ -98,7 +98,7 @@ class VectorValuesWriter { } } - private static class SortingVectorValues extends VectorValues { + static class SortingVectorValues extends VectorValues { private final VectorValues delegate; private final VectorValues.RandomAccess randomAccess; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index fb540481f3b..d2393eaa0fe 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -39,6 +39,7 @@ import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StringField; import org.apache.lucene.document.TextField; +import org.apache.lucene.document.VectorField; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Sort; @@ -126,6 +127,7 @@ public class TestSortingCodecReader extends LuceneTestCase { doc.add(new SortedDocValuesField("binary_sorted_dv", new BytesRef(Integer.toString(docId)))); doc.add(new BinaryDocValuesField("binary_dv", new BytesRef(Integer.toString(docId)))); doc.add(new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId)))); + doc.add(new VectorField("vector", new float[] { (float) docId })); doc.add(new NumericDocValuesField("foo", random().nextInt(20))); FieldType ft = new FieldType(StringField.TYPE_NOT_STORED); @@ -180,6 +182,7 @@ public class TestSortingCodecReader extends LuceneTestCase { SortedNumericDocValues sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv"); SortedSetDocValues sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv"); SortedDocValues binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv"); + VectorValues vectorValues = leaf.getVectorValues("vector"); NumericDocValues ids = leaf.getNumericDocValues("id"); long prevValue = -1; boolean usingAltIds = false; @@ -194,6 +197,7 @@ public class TestSortingCodecReader extends LuceneTestCase { sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv"); sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv"); binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv"); + vectorValues = leaf.getVectorValues("vector"); prevValue = -1; } assertTrue(prevValue + " < " + ids.longValue(), prevValue < ids.longValue()); @@ -202,11 +206,17 @@ public class TestSortingCodecReader extends LuceneTestCase { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); + assertEquals(idNext, vectorValues.advance(idNext)); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals(new BytesRef(ids.longValue() + ""), binary_sorted_dv.binaryValue()); assertEquals(new BytesRef(ids.longValue() + ""), sorted_set_dv.lookupOrd(sorted_set_dv.nextOrd())); assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); + + float[] vectorValue = vectorValues.vectorValue(); + assertEquals(1, vectorValue.length); + assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); + Fields termVectors = leaf.getTermVectors(idNext); assertTrue(termVectors.terms("term_vectors").iterator().seekExact(new BytesRef("test" + ids.longValue()))); assertEquals(Long.toString(ids.longValue()), leaf.document(idNext).get("id"));