LUCENE-9322: Make sure to account for vectors in SortingCodecReader. (#2028)

This commit is contained in:
Julie Tibshirani 2020-10-26 05:00:26 -07:00 committed by GitHub
parent 4bf254158a
commit 37c7d156ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 1 deletion

View File

@ -29,6 +29,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.codecs.VectorReader;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
import org.apache.lucene.util.Bits; import org.apache.lucene.util.Bits;
@ -303,6 +304,32 @@ public final class SortingCodecReader extends FilterCodecReader {
}; };
} }
/**
 * Returns a {@link VectorReader} backed by the wrapped reader's vector reader, whose per-field
 * values are wrapped in {@code SortingVectorValues} together with this reader's {@code docMap}.
 */
@Override
public VectorReader getVectorReader() {
  final VectorReader wrapped = in.getVectorReader();
  return new VectorReader() {
    @Override
    public VectorValues getVectorValues(String field) throws IOException {
      // Fetch the underlying values first, then expose them through the doc map.
      final VectorValues unsorted = wrapped.getVectorValues(field);
      return new VectorValuesWriter.SortingVectorValues(unsorted, docMap);
    }

    // The remaining methods forward directly to the wrapped reader.

    @Override
    public long ramBytesUsed() {
      return wrapped.ramBytesUsed();
    }

    @Override
    public void checkIntegrity() throws IOException {
      wrapped.checkIntegrity();
    }

    @Override
    public void close() throws IOException {
      wrapped.close();
    }
  };
}
@Override @Override
public NormsProducer getNormsReader() { public NormsProducer getNormsReader() {
final NormsProducer delegate = in.getNormsReader(); final NormsProducer delegate = in.getNormsReader();

View File

@ -98,7 +98,7 @@ class VectorValuesWriter {
} }
} }
private static class SortingVectorValues extends VectorValues { static class SortingVectorValues extends VectorValues {
private final VectorValues delegate; private final VectorValues delegate;
private final VectorValues.RandomAccess randomAccess; private final VectorValues.RandomAccess randomAccess;

View File

@ -39,6 +39,7 @@ import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField; import org.apache.lucene.document.TextField;
import org.apache.lucene.document.VectorField;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
@ -126,6 +127,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
doc.add(new SortedDocValuesField("binary_sorted_dv", new BytesRef(Integer.toString(docId)))); doc.add(new SortedDocValuesField("binary_sorted_dv", new BytesRef(Integer.toString(docId))));
doc.add(new BinaryDocValuesField("binary_dv", new BytesRef(Integer.toString(docId)))); doc.add(new BinaryDocValuesField("binary_dv", new BytesRef(Integer.toString(docId))));
doc.add(new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId)))); doc.add(new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId))));
doc.add(new VectorField("vector", new float[] { (float) docId }));
doc.add(new NumericDocValuesField("foo", random().nextInt(20))); doc.add(new NumericDocValuesField("foo", random().nextInt(20)));
FieldType ft = new FieldType(StringField.TYPE_NOT_STORED); FieldType ft = new FieldType(StringField.TYPE_NOT_STORED);
@ -180,6 +182,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
SortedNumericDocValues sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv"); SortedNumericDocValues sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv");
SortedSetDocValues sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv"); SortedSetDocValues sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv");
SortedDocValues binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv"); SortedDocValues binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv");
VectorValues vectorValues = leaf.getVectorValues("vector");
NumericDocValues ids = leaf.getNumericDocValues("id"); NumericDocValues ids = leaf.getNumericDocValues("id");
long prevValue = -1; long prevValue = -1;
boolean usingAltIds = false; boolean usingAltIds = false;
@ -194,6 +197,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv"); sorted_numeric_dv = leaf.getSortedNumericDocValues("sorted_numeric_dv");
sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv"); sorted_set_dv = leaf.getSortedSetDocValues("sorted_set_dv");
binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv"); binary_sorted_dv = leaf.getSortedDocValues("binary_sorted_dv");
vectorValues = leaf.getVectorValues("vector");
prevValue = -1; prevValue = -1;
} }
assertTrue(prevValue + " < " + ids.longValue(), prevValue < ids.longValue()); assertTrue(prevValue + " < " + ids.longValue(), prevValue < ids.longValue());
@ -202,11 +206,17 @@ public class TestSortingCodecReader extends LuceneTestCase {
assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_numeric_dv.advanceExact(idNext));
assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext));
assertTrue(binary_sorted_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext));
assertEquals(idNext, vectorValues.advance(idNext));
assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue());
assertEquals(new BytesRef(ids.longValue() + ""), binary_sorted_dv.binaryValue()); assertEquals(new BytesRef(ids.longValue() + ""), binary_sorted_dv.binaryValue());
assertEquals(new BytesRef(ids.longValue() + ""), sorted_set_dv.lookupOrd(sorted_set_dv.nextOrd())); assertEquals(new BytesRef(ids.longValue() + ""), sorted_set_dv.lookupOrd(sorted_set_dv.nextOrd()));
assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(1, sorted_numeric_dv.docValueCount());
assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue());
float[] vectorValue = vectorValues.vectorValue();
assertEquals(1, vectorValue.length);
assertEquals((float) ids.longValue(), vectorValue[0], 0.001f);
Fields termVectors = leaf.getTermVectors(idNext); Fields termVectors = leaf.getTermVectors(idNext);
assertTrue(termVectors.terms("term_vectors").iterator().seekExact(new BytesRef("test" + ids.longValue()))); assertTrue(termVectors.terms("term_vectors").iterator().seekExact(new BytesRef("test" + ids.longValue())));
assertEquals(Long.toString(ids.longValue()), leaf.document(idNext).get("id")); assertEquals(Long.toString(ids.longValue()), leaf.document(idNext).get("id"));