mirror of https://github.com/apache/lucene.git
Ability to compute vector similarity scores with DoubleValuesSource (#12548)
### Description This PR addresses the issue #12394. It adds an API **`similarityToQueryVector`** to `DoubleValuesSource` to compute vector similarity scores between the query vector and the `KnnByteVectorField`/`KnnFloatVectorField` for documents using the 2 new DVS implementations (`ByteVectorSimilarityValuesSource` for byte vectors and `FloatVectorSimilarityValuesSource` for float vectors). Below are the method signatures added to DVS in this PR: - `DoubleValues similarityToQueryVector(LeafReaderContext ctx, float[] queryVector, String vectorField)` *(uses FloatVectorSimilarityValuesSource)* - `DoubleValues similarityToQueryVector(LeafReaderContext ctx, byte[] queryVector, String vectorField)` *(uses ByteVectorSimilarityValuesSource)* Closes #12394
This commit is contained in:
parent
268dd54a86
commit
52dfe50e8f
|
@ -147,7 +147,8 @@ API Changes
|
||||||
|
|
||||||
New Features
|
New Features
|
||||||
---------------------
|
---------------------
|
||||||
(No changes)
|
* GITHUB#12548: Added similarityToQueryVector API to compute vector similarity scores
|
||||||
|
with DoubleValuesSource. (Shubham Chaudhary)
|
||||||
|
|
||||||
Improvements
|
Improvements
|
||||||
---------------------
|
---------------------
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Objects;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
|
||||||
|
* and the {@link org.apache.lucene.document.KnnByteVectorField} for documents.
|
||||||
|
*/
|
||||||
|
class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
|
||||||
|
private final byte[] queryVector;
|
||||||
|
|
||||||
|
public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
|
||||||
|
super(fieldName);
|
||||||
|
this.queryVector = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
|
||||||
|
final ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||||
|
VectorSimilarityFunction function =
|
||||||
|
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
|
||||||
|
return new DoubleValues() {
|
||||||
|
@Override
|
||||||
|
public double doubleValue() throws IOException {
|
||||||
|
return function.compare(queryVector, vectorValues.vectorValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean advanceExact(int doc) throws IOException {
|
||||||
|
return doc >= vectorValues.docID()
|
||||||
|
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(fieldName, Arrays.hashCode(queryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj) return true;
|
||||||
|
if (obj == null || getClass() != obj.getClass()) return false;
|
||||||
|
ByteVectorSimilarityValuesSource other = (ByteVectorSimilarityValuesSource) obj;
|
||||||
|
return Objects.equals(fieldName, other.fieldName)
|
||||||
|
&& Arrays.equals(queryVector, other.queryVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "ByteVectorSimilarityValuesSource(fieldName="
|
||||||
|
+ fieldName
|
||||||
|
+ " queryVector="
|
||||||
|
+ Arrays.toString(queryVector)
|
||||||
|
+ ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -24,6 +24,7 @@ import java.util.function.LongToDoubleFunction;
|
||||||
import org.apache.lucene.index.DocValues;
|
import org.apache.lucene.index.DocValues;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.NumericDocValues;
|
import org.apache.lucene.index.NumericDocValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.search.comparators.DoubleComparator;
|
import org.apache.lucene.search.comparators.DoubleComparator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -172,6 +173,52 @@ public abstract class DoubleValuesSource implements SegmentCacheable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a DoubleValues instance for computing the vector similarity score per document against
|
||||||
|
* the byte query vector
|
||||||
|
*
|
||||||
|
* @param ctx the context for which to return the DoubleValues
|
||||||
|
* @param queryVector byte query vector
|
||||||
|
* @param vectorField knn byte field name
|
||||||
|
* @return DoubleValues instance
|
||||||
|
* @throws IOException if an {@link IOException} occurs
|
||||||
|
*/
|
||||||
|
public static DoubleValues similarityToQueryVector(
|
||||||
|
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
|
||||||
|
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
|
||||||
|
!= VectorEncoding.BYTE) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Field "
|
||||||
|
+ vectorField
|
||||||
|
+ " does not have the expected vector encoding: "
|
||||||
|
+ VectorEncoding.BYTE);
|
||||||
|
}
|
||||||
|
return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a DoubleValues instance for computing the vector similarity score per document against
|
||||||
|
* the float query vector
|
||||||
|
*
|
||||||
|
* @param ctx the context for which to return the DoubleValues
|
||||||
|
* @param queryVector float query vector
|
||||||
|
* @param vectorField knn float field name
|
||||||
|
* @return DoubleValues instance
|
||||||
|
* @throws IOException if an {@link IOException} occurs
|
||||||
|
*/
|
||||||
|
public static DoubleValues similarityToQueryVector(
|
||||||
|
LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
|
||||||
|
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
|
||||||
|
!= VectorEncoding.FLOAT32) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Field "
|
||||||
|
+ vectorField
|
||||||
|
+ " does not have the expected vector encoding: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
|
}
|
||||||
|
return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a DoubleValuesSource that wraps a generic NumericDocValues field
|
* Creates a DoubleValuesSource that wraps a generic NumericDocValues field
|
||||||
*
|
*
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Objects;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
|
||||||
|
* and the {@link org.apache.lucene.document.KnnFloatVectorField} for documents.
|
||||||
|
*/
|
||||||
|
class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
|
||||||
|
|
||||||
|
private final float[] queryVector;
|
||||||
|
|
||||||
|
public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
|
||||||
|
super(fieldName);
|
||||||
|
this.queryVector = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
|
||||||
|
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
|
||||||
|
VectorSimilarityFunction function =
|
||||||
|
ctx.reader().getFieldInfos().fieldInfo(fieldName).getVectorSimilarityFunction();
|
||||||
|
return new DoubleValues() {
|
||||||
|
@Override
|
||||||
|
public double doubleValue() throws IOException {
|
||||||
|
return function.compare(queryVector, vectorValues.vectorValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean advanceExact(int doc) throws IOException {
|
||||||
|
return doc >= vectorValues.docID()
|
||||||
|
&& (vectorValues.docID() == doc || vectorValues.advance(doc) == doc);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int hashCode() {
|
||||||
|
return Objects.hash(fieldName, Arrays.hashCode(queryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj) {
|
||||||
|
if (this == obj) return true;
|
||||||
|
if (obj == null || getClass() != obj.getClass()) return false;
|
||||||
|
FloatVectorSimilarityValuesSource other = (FloatVectorSimilarityValuesSource) obj;
|
||||||
|
return Objects.equals(fieldName, other.fieldName)
|
||||||
|
&& Arrays.equals(queryVector, other.queryVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "FloatVectorSimilarityValuesSource(fieldName="
|
||||||
|
+ fieldName
|
||||||
|
+ " queryVector="
|
||||||
|
+ Arrays.toString(queryVector)
|
||||||
|
+ ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An abstract class that provides the vector similarity scores between the query vector and the
|
||||||
|
* {@link org.apache.lucene.document.KnnFloatVectorField} or {@link
|
||||||
|
* org.apache.lucene.document.KnnByteVectorField} for documents.
|
||||||
|
*/
|
||||||
|
abstract class VectorSimilarityValuesSource extends DoubleValuesSource {
|
||||||
|
protected final String fieldName;
|
||||||
|
|
||||||
|
public VectorSimilarityValuesSource(String fieldName) {
|
||||||
|
this.fieldName = fieldName;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needsScores() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isCacheable(LeafReaderContext ctx) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,381 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
|
import org.apache.lucene.analysis.Analyzer;
|
||||||
|
import org.apache.lucene.document.Document;
|
||||||
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
|
import org.apache.lucene.document.KnnFloatVectorField;
|
||||||
|
import org.apache.lucene.document.SortedDocValuesField;
|
||||||
|
import org.apache.lucene.document.StringField;
|
||||||
|
import org.apache.lucene.index.IndexReader;
|
||||||
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||||
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
import org.junit.AfterClass;
|
||||||
|
import org.junit.BeforeClass;
|
||||||
|
|
||||||
|
public class TestVectorSimilarityValuesSource extends LuceneTestCase {
|
||||||
|
private static Directory dir;
|
||||||
|
private static Analyzer analyzer;
|
||||||
|
private static IndexReader reader;
|
||||||
|
private static IndexSearcher searcher;
|
||||||
|
|
||||||
|
@BeforeClass
|
||||||
|
public static void beforeClass() throws Exception {
|
||||||
|
dir = newDirectory();
|
||||||
|
analyzer = new MockAnalyzer(random());
|
||||||
|
IndexWriterConfig iwConfig = newIndexWriterConfig(analyzer);
|
||||||
|
iwConfig.setMergePolicy(newLogMergePolicy());
|
||||||
|
RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwConfig);
|
||||||
|
|
||||||
|
Document document = new Document();
|
||||||
|
document.add(new StringField("id", "1", Field.Store.NO));
|
||||||
|
document.add(new SortedDocValuesField("id", new BytesRef("1")));
|
||||||
|
document.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
|
||||||
|
document.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField2",
|
||||||
|
new float[] {2.2f, -3.2f, -3.1f},
|
||||||
|
VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField3", new float[] {4.5f, 10.3f, -7.f}, VectorSimilarityFunction.COSINE));
|
||||||
|
document.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField4",
|
||||||
|
new float[] {-1.3f, 1.0f, 1.0f},
|
||||||
|
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
|
||||||
|
document.add(new KnnFloatVectorField("knnFloatField5", new float[] {-6.7f, -1.0f, -0.9f}));
|
||||||
|
document.add(new KnnByteVectorField("knnByteField1", new byte[] {106, 80, 127}));
|
||||||
|
document.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField2", new byte[] {4, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField3", new byte[] {-121, -64, -1}, VectorSimilarityFunction.COSINE));
|
||||||
|
document.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField4",
|
||||||
|
new byte[] {-127, 127, 127},
|
||||||
|
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
|
||||||
|
iw.addDocument(document);
|
||||||
|
|
||||||
|
Document document2 = new Document();
|
||||||
|
document2.add(new StringField("id", "2", Field.Store.NO));
|
||||||
|
document2.add(new SortedDocValuesField("id", new BytesRef("2")));
|
||||||
|
document2.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
|
||||||
|
document2.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField2",
|
||||||
|
new float[] {-5.2f, 8.7f, 3.1f},
|
||||||
|
VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document2.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField3", new float[] {0.2f, -3.2f, 3.1f}, VectorSimilarityFunction.COSINE));
|
||||||
|
document2.add(new KnnFloatVectorField("knnFloatField5", new float[] {2.f, 13.2f, 9.1f}));
|
||||||
|
document2.add(new KnnByteVectorField("knnByteField1", new byte[] {1, -2, -30}));
|
||||||
|
document2.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField2", new byte[] {40, 21, 3}, VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document2.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField3", new byte[] {9, 2, 3}, VectorSimilarityFunction.COSINE));
|
||||||
|
document2.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField4",
|
||||||
|
new byte[] {14, 29, 31},
|
||||||
|
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
|
||||||
|
iw.addDocument(document2);
|
||||||
|
|
||||||
|
Document document3 = new Document();
|
||||||
|
document3.add(new StringField("id", "3", Field.Store.NO));
|
||||||
|
document3.add(new SortedDocValuesField("id", new BytesRef("3")));
|
||||||
|
document3.add(new KnnFloatVectorField("knnFloatField1", new float[] {1.f, 2.f, 3.f}));
|
||||||
|
document3.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"knnFloatField2", new float[] {-8.f, 7.f, -6.f}, VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document3.add(new KnnFloatVectorField("knnFloatField5", new float[] {5.2f, 3.2f, 3.1f}));
|
||||||
|
document3.add(new KnnByteVectorField("knnByteField1", new byte[] {-128, 0, 127}));
|
||||||
|
document3.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField2", new byte[] {-1, -2, -3}, VectorSimilarityFunction.DOT_PRODUCT));
|
||||||
|
document3.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField3", new byte[] {4, 2, 3}, VectorSimilarityFunction.COSINE));
|
||||||
|
document3.add(
|
||||||
|
new KnnByteVectorField(
|
||||||
|
"knnByteField4",
|
||||||
|
new byte[] {-4, -2, -128},
|
||||||
|
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT));
|
||||||
|
document3.add(new KnnByteVectorField("knnByteField5", new byte[] {-120, -2, 3}));
|
||||||
|
iw.addDocument(document3);
|
||||||
|
|
||||||
|
reader = iw.getReader();
|
||||||
|
searcher = newSearcher(reader);
|
||||||
|
iw.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterClass
|
||||||
|
public static void afterClass() throws Exception {
|
||||||
|
searcher = null;
|
||||||
|
IOUtils.close(reader, dir, analyzer);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testEuclideanSimilarityValuesSource() throws Exception {
|
||||||
|
float[] floatQueryVector = new float[] {9.f, 1.f, -10.f};
|
||||||
|
|
||||||
|
// Checks the computed similarity score between indexed vectors and query vector
|
||||||
|
// using DVS is correct by passing indexed and query vector in #compare
|
||||||
|
DoubleValues dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField1");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField5");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {-6.7f, -1.0f, -0.9f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {2.f, 13.2f, 9.1f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new float[] {5.2f, 3.2f, 3.1f}, floatQueryVector));
|
||||||
|
|
||||||
|
byte[] byteQueryVector = new byte[] {-128, 2, 127};
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField1");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new byte[] {106, 80, 127}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new byte[] {1, -2, -30}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new byte[] {-128, 0, 127}, byteQueryVector));
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField5");
|
||||||
|
assertFalse(dv.advanceExact(0));
|
||||||
|
assertFalse(dv.advanceExact(1));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.EUCLIDEAN.compare(
|
||||||
|
new byte[] {-120, -2, 3}, byteQueryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDotSimilarityValuesSource() throws Exception {
|
||||||
|
float[] floatQueryVector = new float[] {10.f, 1.f, -8.5f};
|
||||||
|
|
||||||
|
// Checks the computed similarity score between indexed vectors and query vector
|
||||||
|
// using DVS is correct by passing indexed and query vector in #compare
|
||||||
|
DoubleValues dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField2");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new float[] {2.2f, -3.2f, -3.1f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new float[] {-5.2f, 8.7f, 3.1f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new float[] {-8.f, 7.f, -6.f}, floatQueryVector));
|
||||||
|
|
||||||
|
byte[] byteQueryVector = new byte[] {-128, 2, 127};
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField2");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new byte[] {4, 2, 3}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new byte[] {40, 21, 3}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new byte[] {-1, -2, -3}, byteQueryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testCosineSimilarityValuesSource() throws Exception {
|
||||||
|
float[] floatQueryVector = new float[] {0.6f, -1.6f, 38.0f};
|
||||||
|
|
||||||
|
// Checks the computed similarity score between indexed vectors and query vector
|
||||||
|
// using DVS is correct by passing indexed and query vector in #compare
|
||||||
|
DoubleValues dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField3");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.COSINE.compare(
|
||||||
|
new float[] {4.5f, 10.3f, -7.f}, floatQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.COSINE.compare(
|
||||||
|
new float[] {0.2f, -3.2f, 3.1f}, floatQueryVector));
|
||||||
|
assertFalse(dv.advanceExact(2));
|
||||||
|
|
||||||
|
byte[] byteQueryVector = new byte[] {-10, 8, 0};
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField3");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.COSINE.compare(
|
||||||
|
new byte[] {-121, -64, -1}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.COSINE.compare(new byte[] {9, 2, 3}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.COSINE.compare(new byte[] {4, 2, 3}, byteQueryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testMaximumProductSimilarityValuesSource() throws Exception {
|
||||||
|
float[] floatQueryVector = new float[] {1.f, -6.f, -10.f};
|
||||||
|
|
||||||
|
// Checks the computed similarity score between indexed vectors and query vector
|
||||||
|
// using DVS is correct by passing indexed and query vector in #compare
|
||||||
|
DoubleValues dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField4");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
|
||||||
|
new float[] {-1.3f, 1.0f, 1.0f}, floatQueryVector));
|
||||||
|
assertFalse(dv.advanceExact(1));
|
||||||
|
assertFalse(dv.advanceExact(2));
|
||||||
|
|
||||||
|
byte[] byteQueryVector = new byte[] {-127, 127, 127};
|
||||||
|
|
||||||
|
dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnByteField4");
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(0)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
|
||||||
|
new byte[] {-127, 127, 127}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(1)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
|
||||||
|
new byte[] {14, 29, 31}, byteQueryVector));
|
||||||
|
assertTrue(
|
||||||
|
dv.advanceExact(2)
|
||||||
|
&& dv.doubleValue()
|
||||||
|
== VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
|
||||||
|
new byte[] {-4, -2, -128}, byteQueryVector));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testFailuresWithSimilarityValuesSource() throws Exception {
|
||||||
|
float[] floatQueryVector = new float[] {1.1f, 2.2f, 3.3f};
|
||||||
|
byte[] byteQueryVector = new byte[] {-10, 20, 30};
|
||||||
|
|
||||||
|
expectThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnByteField1"));
|
||||||
|
expectThrows(
|
||||||
|
IllegalArgumentException.class,
|
||||||
|
() ->
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), byteQueryVector, "knnFloatField1"));
|
||||||
|
|
||||||
|
DoubleValues dv =
|
||||||
|
DoubleValuesSource.similarityToQueryVector(
|
||||||
|
searcher.reader.leaves().get(0), floatQueryVector, "knnFloatField1");
|
||||||
|
assertTrue(dv.advanceExact(0));
|
||||||
|
assertEquals(
|
||||||
|
dv.doubleValue(),
|
||||||
|
VectorSimilarityFunction.EUCLIDEAN.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector),
|
||||||
|
0.0);
|
||||||
|
assertNotEquals(
|
||||||
|
dv.doubleValue(),
|
||||||
|
VectorSimilarityFunction.DOT_PRODUCT.compare(
|
||||||
|
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
assertNotEquals(
|
||||||
|
dv.doubleValue(),
|
||||||
|
VectorSimilarityFunction.COSINE.compare(new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
assertNotEquals(
|
||||||
|
dv.doubleValue(),
|
||||||
|
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(
|
||||||
|
new float[] {1.f, 2.f, 3.f}, floatQueryVector));
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue