Update the signature of vector script functions. (#48653)

Previously the functions accepted a doc values reference, whereas they now
accept the name of the vector field. Here's an example of how a vector function
was called before and after the change.

```
Before: cosineSimilarity(params.query_vector, doc['field'])
After:  cosineSimilarity(params.query_vector, 'field')
```

This seems more intuitive, since we don't allow direct access to vector doc
values and the the meaning of `doc['field']` is unclear.

The PR makes the following changes (broken into distinct commits):
* Add new function signatures of the form `function(params.query_vector,
'field')` and deprecates the old ones. Because Painless doesn't allow two
methods with the same name and number of arguments, we allow a generic `Object`
to be passed in to the function and decide on the behavior through an
`instanceof` check.
* Refactor the class bindings so that the document field is passed to the
constructor instead of the instance method. This allows us to avoid retrieving
the vector doc values on every function invocation, which gives a tiny speed-up
in benchmarks.

Note that this PR adds new signatures for the sparse vector functions too, even
though sparse vectors are deprecated. It seemed simplest to understand (for both
us and users) to keep everything symmetric between dense and sparse vectors.
This commit is contained in:
Julie Tibshirani 2019-10-29 15:46:05 -07:00 committed by GitHub
parent 25724c5c46
commit 89c65752dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 377 additions and 144 deletions

View File

@ -10,8 +10,8 @@ The following specialized API is available in the Score context.
==== Static Methods
The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values.
* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues)
* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues)
* double cosineSimilarity(List *, String)
* double cosineSimilaritySparse(Map *, String)
* double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
* double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
* double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier.
* double decayNumericExp(double *, double *, double *, double *, double)
* double decayNumericGauss(double *, double *, double *, double *, double)
* double decayNumericLinear(double *, double *, double *, double *, double)
* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues)
* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues)
* double dotProduct(List, String)
* double dotProductSparse(Map *, String)
* double randomScore(int *)
* double randomScore(int *, String *)
* double saturation(double, double)

View File

@ -29,3 +29,10 @@ We have not seen much interest in this experimental field type, and don't see
a clear use case as it's currently designed. If you have feedback or
suggestions around sparse vector functionality, please let us know through
GitHub or the 'discuss' forums.
[discrete]
==== Update to vector function signatures
The vector functions of the form `function(query, doc['field'])` are
deprecated, and the form `function(query, 'field')` should be used instead.
For example, `cosineSimilarity(query, doc['field'])` is replaced by
`cosineSimilarity(query, 'field')`.

View File

@ -68,7 +68,7 @@ GET my_index/_search
}
},
"script": {
"source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2>
"source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2>
"params": {
"query_vector": [4, 3.4, -0.2] <3>
}
@ -105,7 +105,7 @@ GET my_index/_search
},
"script": {
"source": """
double value = dotProduct(params.query_vector, doc['my_dense_vector']);
double value = dotProduct(params.query_vector, 'my_dense_vector');
return sigmoid(1, Math.E, -value); <1>
""",
"params": {
@ -139,7 +139,7 @@ GET my_index/_search
}
},
"script": {
"source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1>
"source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1>
"params": {
"queryVector": [4, 3.4, -0.2]
}
@ -178,7 +178,7 @@ GET my_index/_search
}
},
"script": {
"source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))",
"source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))",
"params": {
"queryVector": [4, 3.4, -0.2]
}
@ -196,7 +196,7 @@ You can check if a document has a value for the field `my_vector` by
[source,js]
--------------------------------------------------
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, doc['my_vector'])"
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')"
--------------------------------------------------
// NOTCONSOLE
@ -262,7 +262,7 @@ GET my_sparse_index/_search
}
},
"script": {
"source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0",
"source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0",
"params": {
"query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
}
@ -294,7 +294,7 @@ GET my_sparse_index/_search
},
"script": {
"source": """
double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']);
double value = dotProductSparse(params.query_vector, 'my_sparse_vector');
return sigmoid(1, Math.E, -value);
""",
"params": {
@ -327,7 +327,7 @@ GET my_sparse_index/_search
}
},
"script": {
"source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))",
"source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))",
"params": {
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
}
@ -358,7 +358,7 @@ GET my_sparse_index/_search
}
},
"script": {
"source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))",
"source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))",
"params": {
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
}

View File

@ -119,7 +119,7 @@ public abstract class ScoreScript {
}
/** The doc lookup for the Lucene segment this script was created for. */
public final Map<String, ScriptDocValues<?>> getDoc() {
public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc();
}

View File

@ -1,6 +1,6 @@
setup:
- skip:
features: headers
features: [headers, warnings]
version: " - 7.2.99"
reason: "dense_vector dims parameter was added from 7.3"
@ -52,7 +52,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, doc['my_dense_vector'])"
source: "dotProduct(params.query_vector, 'my_dense_vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -82,7 +82,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -99,3 +99,26 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791}
---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}

View File

@ -53,7 +53,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l1norm(params.query_vector, doc['my_dense_vector'])"
source: "l1norm(params.query_vector, 'my_dense_vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -83,7 +83,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l2norm(params.query_vector, doc['my_dense_vector'])"
source: "l2norm(params.query_vector, 'my_dense_vector')"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

View File

@ -62,7 +62,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [10, 10, 10]
@ -81,7 +81,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [10.0, 10.0, 10.0]
@ -111,7 +111,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [1, 2, 3, 4]
- match: { error.root_cause.0.type: "script_exception" }
@ -125,7 +125,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, doc['my_dense_vector'])"
source: "dotProduct(params.query_vector, 'my_dense_vector')"
params:
query_vector: [1, 2, 3, 4]
- match: { error.root_cause.0.type: "script_exception" }
@ -161,7 +161,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [10.0, 10.0, 10.0]
- match: { error.root_cause.0.type: "script_exception" }
@ -177,7 +177,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, 'my_dense_vector')"
params:
query_vector: [10.0, 10.0, 10.0]
@ -208,7 +208,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])"
source: "dotProductSparse(params.query_vector, 'my_dense_vector')"
params:
query_vector: {"2": 0.5, "10" : 111.3, "3": 44}
- match: { error.root_cause.0.type: "script_exception" }

View File

@ -55,7 +55,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])"
source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -87,7 +87,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -104,3 +104,27 @@ setup:
- match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791}
---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The [sparse_vector] field type is deprecated and will be removed in 8.0.
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
params:
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}

View File

@ -55,7 +55,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])"
source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -88,7 +88,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])"
source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}

View File

@ -61,7 +61,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10}
@ -83,7 +83,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10.0}
@ -127,7 +127,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10.0}
- match: { error.root_cause.0.type: "script_exception" }
@ -145,7 +145,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10.0}
@ -194,7 +194,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555}
@ -229,7 +229,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProduct(params.query_vector, doc['my_sparse_vector'])"
source: "dotProduct(params.query_vector, 'my_sparse_vector')"
params:
query_vector: [0.5, 111]
- match: { error.root_cause.0.type: "script_exception" }
@ -272,7 +272,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])"
source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10, "5": 5}
@ -303,7 +303,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10, "5" : 5}
@ -333,7 +333,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])"
source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10, "5": 5}
@ -360,7 +360,7 @@ setup:
script_score:
query: {match_all: {} }
script:
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])"
source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
params:
query_vector: {"1": 10, "5": 5}

View File

@ -9,12 +9,16 @@ package org.elasticsearch.xpack.vectors.query;
import org.apache.logging.log4j.LogManager;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
@ -22,6 +26,10 @@ import java.util.Map;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
public class ScoreScriptUtils {
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class));
static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " +
"the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " +
"cosineSimilarity(query, 'field').";
//**************FUNCTIONS FOR DENSE VECTORS
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
@ -31,9 +39,12 @@ public class ScoreScriptUtils {
public static class DenseVectorFunction {
final ScoreScript scoreScript;
final float[] queryVector;
final VectorScriptDocValues.DenseVectorScriptDocValues docValues;
public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
this(scoreScript, queryVector, false);
public DenseVectorFunction(ScoreScript scoreScript,
List<Number> queryVector,
Object field) {
this(scoreScript, queryVector, field, false);
}
/**
@ -45,6 +56,7 @@ public class ScoreScriptUtils {
*/
public DenseVectorFunction(ScoreScript scoreScript,
List<Number> queryVector,
Object field,
boolean normalizeQuery) {
this.scoreScript = scoreScript;
@ -62,9 +74,28 @@ public class ScoreScriptUtils {
this.queryVector[dim] /= queryMagnitude;
}
}
if (field instanceof String) {
String fieldName = (String) field;
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof DenseVectorScriptDocValues) {
docValues = (DenseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}
}
public void validateDocVector(BytesRef vector) {
BytesRef getEncodedVector() {
try {
docValues.setNextDocId(scoreScript._getDocId());
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
// Validate the encoded vector's length.
BytesRef vector = docValues.getEncodedValue();
if (vector == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
@ -74,20 +105,21 @@ public class ScoreScriptUtils {
throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
queryVector.length + "] than the document vectors [" + vectorLength + "].");
}
return vector;
}
}
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
public static final class L1Norm extends DenseVectorFunction {
public L1Norm(ScoreScript scoreScript, List<Number> queryVector) {
super(scoreScript, queryVector);
public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector, field);
}
public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double l1norm() {
BytesRef vector = getEncodedVector();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double l1norm = 0;
for (float queryValue : queryVector) {
@ -100,13 +132,12 @@ public class ScoreScriptUtils {
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
public static final class L2Norm extends DenseVectorFunction {
public L2Norm(ScoreScript scoreScript, List<Number> queryVector) {
super(scoreScript, queryVector);
public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector, field);
}
public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double l2norm() {
BytesRef vector = getEncodedVector();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double l2norm = 0;
@ -121,13 +152,12 @@ public class ScoreScriptUtils {
// Calculate a dot product between a query's dense vector and documents' dense vectors
public static final class DotProduct extends DenseVectorFunction {
public DotProduct(ScoreScript scoreScript, List<Number> queryVector) {
super(scoreScript, queryVector);
public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector, field);
}
public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double dotProduct() {
BytesRef vector = getEncodedVector();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double dotProduct = 0;
@ -141,14 +171,12 @@ public class ScoreScriptUtils {
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
public static final class CosineSimilarity extends DenseVectorFunction {
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) {
super(scoreScript, queryVector, true);
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector, field, true);
}
public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double cosineSimilarity() {
BytesRef vector = getEncodedVector();
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double dotProduct = 0.0;
@ -176,15 +204,17 @@ public class ScoreScriptUtils {
// per script execution for all documents.
public static class SparseVectorFunction {
static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class));
final ScoreScript scoreScript;
final float[] queryValues;
final int[] queryDims;
final VectorScriptDocValues.SparseVectorScriptDocValues docValues;
// prepare queryVector once per script execution
// queryVector represents a map of dimensions to values
public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector) {
public SparseVectorFunction(ScoreScript scoreScript,
Map<String, Number> queryVector,
Object field) {
this.scoreScript = scoreScript;
//break vector into two arrays dims and values
int n = queryVector.size();
@ -203,28 +233,46 @@ public class ScoreScriptUtils {
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
sortSparseDimsFloatValues(queryDims, queryValues, n);
if (field instanceof String) {
String fieldName = (String) field;
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof SparseVectorScriptDocValues) {
docValues = (SparseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void validateDocVector(BytesRef vector) {
BytesRef getEncodedVector() {
try {
docValues.setNextDocId(scoreScript._getDocId());
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
BytesRef vector = docValues.getEncodedValue();
if (vector == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
return vector;
}
}
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
public static final class L1NormSparse extends SparseVectorFunction {
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector) {
super(scoreScript, queryVector);
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
}
public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double l1normSparse() {
BytesRef vector = getEncodedVector();
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
int queryIndex = 0;
int docIndex = 0;
double l1norm = 0;
@ -255,16 +303,15 @@ public class ScoreScriptUtils {
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
public static final class L2NormSparse extends SparseVectorFunction {
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
super(scoreScript, queryVector);
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
}
public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double l2normSparse() {
BytesRef vector = getEncodedVector();
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
int queryIndex = 0;
int docIndex = 0;
double l2norm = 0;
@ -298,16 +345,15 @@ public class ScoreScriptUtils {
// Calculate a dot product between a query's sparse vector and documents' sparse vectors
public static final class DotProductSparse extends SparseVectorFunction {
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
super(scoreScript, queryVector);
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
}
public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double dotProductSparse() {
BytesRef vector = getEncodedVector();
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
return intDotProductSparse(queryValues, queryDims, docValues, docDims);
}
}
@ -316,8 +362,8 @@ public class ScoreScriptUtils {
public static final class CosineSimilaritySparse extends SparseVectorFunction {
final double queryVectorMagnitude;
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
super(scoreScript, queryVector);
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector, docVector);
double dotProduct = 0;
for (int i = 0; i< queryDims.length; i++) {
dotProduct += queryValues[i] * queryValues[i];
@ -325,10 +371,8 @@ public class ScoreScriptUtils {
this.queryVectorMagnitude = Math.sqrt(dotProduct);
}
public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
BytesRef vector = dvs.getEncodedValue();
validateDocVector(vector);
public double cosineSimilaritySparse() {
BytesRef vector = getEncodedVector();
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);

View File

@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import {
}
static_import {
double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
double l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
}

View File

@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -39,6 +40,66 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}
private void testDenseVectorFunctions(Version indexVersion) {
String field = "vector";
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field);
double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field);
double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector, field);
double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector, field);
double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
// test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
public void testDeprecatedDenseVectorFunctions() {
testDeprecatedDenseVectorFunctions(Version.V_7_4_0);
testDeprecatedDenseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedDenseVectorFunctions(Version indexVersion) {
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
@ -50,45 +111,53 @@ public class ScoreScriptUtilsTests extends ESTestCase {
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector);
double result = dotProduct.dotProduct(dvs);
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs);
double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector);
double result2 = cosineSimilarity.cosineSimilarity(dvs);
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs);
double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector);
double result3 = l1norm.l1norm(dvs);
L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs);
double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector);
double result4 = l2norm.l2norm(dvs);
L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs);
double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs));
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector);
e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs));
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector);
e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs));
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector);
e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs));
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
}
public void testSparseVectorFunctions() {
@ -97,12 +166,62 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}
private void testSparseVectorFunctions(Version indexVersion) {
String field = "vector";
int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("4545", -156.0);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testDeprecatedSparseVectorFunctions() {
testDeprecatedSparseVectorFunctions(Version.V_7_4_0);
testDeprecatedSparseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedSparseVectorFunctions(Version indexVersion) {
int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
@ -115,29 +234,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
double result = docProductSparse.dotProductSparse(dvs);
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
double result3 = l1Norm.l1normSparse(dvs);
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
double result4 = l2Norm.l2normSparse(dvs);
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
}
public void testSparseVectorMissingDimensions1() {
String field = "vector";
// Document vector's biggest dimension > query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
@ -145,8 +268,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
@ -157,29 +283,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
double result = docProductSparse.dotProductSparse(dvs);
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
double result3 = l1Norm.l1normSparse(dvs);
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
double result4 = l2Norm.l2normSparse(dvs);
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testSparseVectorMissingDimensions2() {
String field = "vector";
// Document vector's biggest dimension < query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
@ -187,8 +317,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
@ -199,25 +332,27 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
double result = docProductSparse.dotProductSparse(dvs);
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
double result3 = l1Norm.l1normSparse(dvs);
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
double result4 = l2Norm.l2normSparse(dvs);
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
}