LUCENE-10015: Remove VectorSimilarityFunction#NONE. (#219)

This commit is contained in:
Adrien Grand 2021-07-21 10:06:27 +02:00 committed by GitHub
parent acf45d8a31
commit 28ba8b7797
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 83 additions and 132 deletions

View File

@ -214,7 +214,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
pointIndexDimensionCount,
pointNumBytes,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField);
} catch (IllegalStateException e) {
throw new CorruptIndexException(

View File

@ -116,7 +116,7 @@ public class TestBlockWriter extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
}
}

View File

@ -203,7 +203,7 @@ public class TestSTBlockReader extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false);
}

View File

@ -200,16 +200,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
private FieldEntry readField(DataInput input) throws IOException {
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
switch (similarityFunction) {
case NONE:
return new FieldEntry(input, similarityFunction);
case DOT_PRODUCT:
case EUCLIDEAN:
return new HnswGraphFieldEntry(input, similarityFunction);
default:
throw new CorruptIndexException(
"Unknown vector similarity function: " + similarityFunction, input);
}
return new FieldEntry(input, similarityFunction);
}
@Override
@ -299,14 +290,9 @@ public final class Lucene90HnswVectorReader extends VectorReader {
}
private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException {
if (entry.similarityFunction != VectorSimilarityFunction.NONE) {
HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry;
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
return new IndexedKnnGraphReader(graphEntry, bytesSlice);
} else {
return KnnGraphValues.EMPTY;
}
IndexInput bytesSlice =
vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength);
return new IndexedKnnGraphReader(entry, bytesSlice);
}
@Override
@ -324,6 +310,7 @@ public final class Lucene90HnswVectorReader extends VectorReader {
final long indexDataOffset;
final long indexDataLength;
final int[] ordToDoc;
final long[] ordOffsets;
FieldEntry(DataInput input, VectorSimilarityFunction similarityFunction) throws IOException {
this.similarityFunction = similarityFunction;
@ -338,20 +325,6 @@ public final class Lucene90HnswVectorReader extends VectorReader {
int doc = input.readVInt();
ordToDoc[i] = doc;
}
}
int size() {
return ordToDoc.length;
}
}
private static class HnswGraphFieldEntry extends FieldEntry {
final long[] ordOffsets;
HnswGraphFieldEntry(DataInput input, VectorSimilarityFunction similarityFunction)
throws IOException {
super(input, similarityFunction);
ordOffsets = new long[size()];
long offset = 0;
for (int i = 0; i < ordOffsets.length; i++) {
@ -359,6 +332,10 @@ public final class Lucene90HnswVectorReader extends VectorReader {
ordOffsets[i] = offset;
}
}
int size() {
return ordToDoc.length;
}
}
/** Read the vector values from the index input. This supports both iterated and random access. */
@ -472,14 +449,14 @@ public final class Lucene90HnswVectorReader extends VectorReader {
/** Read the nearest-neighbors graph from the index input */
private static final class IndexedKnnGraphReader extends KnnGraphValues {
final HnswGraphFieldEntry entry;
final FieldEntry entry;
final IndexInput dataIn;
int arcCount;
int arcUpTo;
int arc;
IndexedKnnGraphReader(HnswGraphFieldEntry entry, IndexInput dataIn) {
IndexedKnnGraphReader(FieldEntry entry, IndexInput dataIn) {
this.entry = entry;
this.dataIn = dataIn;
}

View File

@ -127,21 +127,19 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
long[] offsets = new long[count];
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) {
if (vectors instanceof RandomAccessVectorValuesProducer) {
writeGraph(
vectorIndex,
(RandomAccessVectorValuesProducer) vectors,
fieldInfo.getVectorSimilarityFunction(),
vectorIndexOffset,
offsets,
count,
maxConn,
beamWidth);
} else {
throw new IllegalArgumentException(
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
}
if (vectors instanceof RandomAccessVectorValuesProducer) {
writeGraph(
vectorIndex,
(RandomAccessVectorValuesProducer) vectors,
fieldInfo.getVectorSimilarityFunction(),
vectorIndexOffset,
offsets,
count,
maxConn,
beamWidth);
} else {
throw new IllegalArgumentException(
"Indexing an HNSW graph requires a random access vector values, got " + vectors);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
@ -152,9 +150,7 @@ public final class Lucene90HnswVectorWriter extends VectorWriter {
vectorIndexLength,
count,
docIds);
if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) {
writeGraphOffsets(meta, offsets);
}
writeGraphOffsets(meta, offsets);
}
private void writeMeta(

View File

@ -18,6 +18,7 @@ package org.apache.lucene.document;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.analysis.Analyzer; // javadocs
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexOptions;
@ -43,7 +44,7 @@ public class FieldType implements IndexableFieldType {
private int indexDimensionCount;
private int dimensionNumBytes;
private int vectorDimension;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private Map<String, String> attributes;
/** Create a new mutable FieldType with all of the properties from <code>ref</code> */
@ -384,7 +385,7 @@ public class FieldType implements IndexableFieldType {
+ numDimensions);
}
this.vectorDimension = numDimensions;
this.vectorSimilarityFunction = distFunc;
this.vectorSimilarityFunction = Objects.requireNonNull(distFunc);
}
@Override

View File

@ -202,14 +202,6 @@ public final class FieldInfo {
throw new IllegalArgumentException(
"vectorDimension must be >=0; got " + vectorDimension + " (field: '" + name + "')");
}
if (vectorDimension == 0 && vectorSimilarityFunction != VectorSimilarityFunction.NONE) {
throw new IllegalArgumentException(
"vector similarity function must be NONE when dimension = 0; got "
+ vectorSimilarityFunction
+ " (field: '"
+ name
+ "')");
}
}
/**

View File

@ -486,7 +486,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
addOrGet(fi);
}
@ -567,7 +567,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
isSoftDeletesField);
}

View File

@ -1327,7 +1327,7 @@ final class IndexingChain implements Accountable {
private int pointIndexDimensionCount = 0;
private int pointNumBytes = 0;
private int vectorDimension = 0;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE;
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
private static String errMsg =
"Inconsistency of field data structures across documents for field ";
@ -1383,7 +1383,7 @@ final class IndexingChain implements Accountable {
}
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) {
if (vectorSimilarityFunction == VectorSimilarityFunction.NONE) {
if (vectorDimension == 0) {
this.vectorDimension = dimension;
this.vectorSimilarityFunction = similarityFunction;
} else {
@ -1402,7 +1402,7 @@ final class IndexingChain implements Accountable {
pointIndexDimensionCount = 0;
pointNumBytes = 0;
vectorDimension = 0;
vectorSimilarityFunction = VectorSimilarityFunction.NONE;
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
}
void assertSameSchema(FieldInfo fi) {

View File

@ -19,8 +19,6 @@ package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.squareDistance;
import org.apache.lucene.codecs.VectorReader;
/**
* Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors
@ -28,17 +26,21 @@ import org.apache.lucene.codecs.VectorReader;
*/
public enum VectorSimilarityFunction {
/**
* No similarity function is provided. Note: {@link VectorReader#search(String, float[], int)} is
* not supported for fields specifying this.
*/
NONE,
/** Euclidean distance */
EUCLIDEAN(true) {
@Override
public float compare(float[] v1, float[] v2) {
return squareDistance(v1, v2);
}
},
/** HNSW graph built using Euclidean distance */
EUCLIDEAN(true),
/** HNSW graph built using dot product */
DOT_PRODUCT;
/** Dot product */
DOT_PRODUCT {
@Override
public float compare(float[] v1, float[] v2) {
return dotProduct(v1, v2);
}
};
/**
* If true, the scores associated with vector comparisons are in reverse order; that is, lower
@ -62,15 +64,5 @@ public enum VectorSimilarityFunction {
* @param v2 another vector, of the same dimension
* @return the value of the similarity function applied to the two vectors
*/
public float compare(float[] v1, float[] v2) {
switch (this) {
case EUCLIDEAN:
return squareDistance(v1, v2);
case DOT_PRODUCT:
return dotProduct(v1, v2);
case NONE:
default:
throw new IllegalStateException("Incomparable similarity function: " + this);
}
}
public abstract float compare(float[] v1, float[] v2);
}

View File

@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw;
import java.io.IOException;
import java.util.Locale;
import java.util.Objects;
import java.util.Random;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
@ -74,10 +75,7 @@ public final class HnswGraphBuilder {
long seed) {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
this.similarityFunction = similarityFunction;
if (similarityFunction == VectorSimilarityFunction.NONE) {
throw new IllegalStateException("No distance function");
}
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (maxConn <= 0) {
throw new IllegalArgumentException("maxConn must be positive");
}

View File

@ -90,9 +90,6 @@ public class TestPerFieldConsistency extends LuceneTestCase {
private static Field randomVectorField(Random random, String fieldName) {
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(random, VectorSimilarityFunction.values());
while (similarityFunction == VectorSimilarityFunction.NONE) {
similarityFunction = RandomPicks.randomFrom(random, VectorSimilarityFunction.values());
}
float[] values = new float[randomIntBetween(1, 10)];
for (int i = 0; i < values.length; i++) {
values[i] = randomFloat();

View File

@ -112,7 +112,7 @@ public class TestCodecs extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false));
}
this.terms = terms;

View File

@ -260,7 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false));
}
int idx =
@ -279,7 +279,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false));
assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
@ -300,7 +300,7 @@ public class TestFieldInfos extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false));
assertEquals("Field numbers should reset after clear()", 0, idx);
}

View File

@ -63,7 +63,7 @@ public class TestFieldsReader extends LuceneTestCase {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
field.name().equals(softDeletesFieldName)));
}
dir = newDirectory();

View File

@ -114,7 +114,7 @@ public class TestIndexableField extends LuceneTestCase {
@Override
public VectorSimilarityFunction vectorSimilarityFunction() {
return VectorSimilarityFunction.NONE;
return VectorSimilarityFunction.EUCLIDEAN;
}
@Override

View File

@ -17,8 +17,6 @@
package org.apache.lucene.index;
import static org.apache.lucene.index.VectorSimilarityFunction.NONE;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
@ -198,7 +196,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
List<DocValuesFieldUpdates> updates = Arrays.asList(singleUpdate(docsDeleted, 10, true));
@ -235,7 +233,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
for (DocValuesFieldUpdates update : updates) {
deletes.onDocValuesUpdate(fieldInfo, update.iterator());
@ -297,7 +295,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
List<DocValuesFieldUpdates> updates = Arrays.asList(singleUpdate(docsDeleted, 3, true));
@ -364,7 +362,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
List<DocValuesFieldUpdates> updates =
Arrays.asList(singleUpdate(Arrays.asList(0, 1, DocIdSetIterator.NO_MORE_DOCS), 3, false));
@ -400,7 +398,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
0,
0,
0,
NONE,
VectorSimilarityFunction.EUCLIDEAN,
true);
updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));
for (DocValuesFieldUpdates update : updates) {

View File

@ -98,7 +98,7 @@ public class TermVectorLeafReader extends LeafReader {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
}

View File

@ -412,7 +412,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false);
}
}

View File

@ -49,7 +49,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
@Override
protected void addRandomFields(Document doc) {
doc.add(new VectorField("v2", randomVector(30), VectorSimilarityFunction.NONE));
doc.add(new VectorField("v2", randomVector(30), VectorSimilarityFunction.EUCLIDEAN));
}
public void testFieldConstructor() {
@ -461,11 +461,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Exception e =
expectThrows(
IllegalArgumentException.class,
() -> doc.add(new VectorField("f", new float[0], VectorSimilarityFunction.NONE)));
() ->
doc.add(new VectorField("f", new float[0], VectorSimilarityFunction.EUCLIDEAN)));
assertEquals("cannot index an empty vector", e.getMessage());
Document doc2 = new Document();
doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.NONE));
doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc2);
}
}
@ -509,7 +510,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
}
public void testInvalidVectorFieldUsage() {
VectorField field = new VectorField("field", new float[2], VectorSimilarityFunction.NONE);
VectorField field = new VectorField("field", new float[2], VectorSimilarityFunction.EUCLIDEAN);
expectThrows(IllegalArgumentException.class, () -> field.setIntValue(14));
@ -681,7 +682,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Document doc = new Document();
float[] v = new float[] {1};
doc.add(new VectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN));
doc.add(new VectorField("field2", new float[] {1, 2, 3}, VectorSimilarityFunction.NONE));
doc.add(new VectorField("field2", new float[] {1, 2, 3}, VectorSimilarityFunction.EUCLIDEAN));
iw.addDocument(doc);
v[0] = 2;
iw.addDocument(doc);
@ -748,9 +749,9 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
if (random().nextBoolean() && values[i] != null) {
// sometimes use a shared scratch array
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
add(iw, fieldName, i, scratch, VectorSimilarityFunction.NONE);
add(iw, fieldName, i, scratch, VectorSimilarityFunction.EUCLIDEAN);
} else {
add(iw, fieldName, i, values[i], VectorSimilarityFunction.NONE);
add(iw, fieldName, i, values[i], VectorSimilarityFunction.EUCLIDEAN);
}
if (random().nextInt(10) == 2) {
// sometimes delete a random document
@ -865,7 +866,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
private void add(IndexWriter iw, String field, int id, int sortkey, float[] vector)
throws IOException {
add(iw, field, id, sortkey, vector, VectorSimilarityFunction.NONE);
add(iw, field, id, sortkey, vector, VectorSimilarityFunction.EUCLIDEAN);
}
private void add(
@ -900,10 +901,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new VectorField("v1", randomVector(3), VectorSimilarityFunction.NONE));
doc.add(new VectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
doc.add(new VectorField("v2", randomVector(3), VectorSimilarityFunction.NONE));
doc.add(new VectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc);
}
@ -924,10 +925,9 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
public void testSimilarityFunctionIdentifiers() {
// make sure we don't accidentally mess up similarity function identifiers by re-ordering their
// enumerators
assertEquals(0, VectorSimilarityFunction.NONE.ordinal());
assertEquals(1, VectorSimilarityFunction.EUCLIDEAN.ordinal());
assertEquals(2, VectorSimilarityFunction.DOT_PRODUCT.ordinal());
assertEquals(3, VectorSimilarityFunction.values().length);
assertEquals(0, VectorSimilarityFunction.EUCLIDEAN.ordinal());
assertEquals(1, VectorSimilarityFunction.DOT_PRODUCT.ordinal());
assertEquals(2, VectorSimilarityFunction.values().length);
}
public void testAdvance() throws Exception {
@ -939,7 +939,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa
Document doc = new Document();
// randomly add a vector field
if (random().nextInt(4) == 3) {
doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.NONE));
doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.EUCLIDEAN));
}
w.addDocument(doc);
}

View File

@ -140,7 +140,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false);
fieldUpto++;
@ -711,7 +711,7 @@ public class RandomPostingsTester {
0,
0,
0,
VectorSimilarityFunction.NONE,
VectorSimilarityFunction.EUCLIDEAN,
false);
}