Update the signature of vector script functions. (#48653)

Previously the functions accepted a doc values reference, whereas they now
accept the name of the vector field. Here's an example of how a vector function
was called before and after the change.

```
Before: cosineSimilarity(params.query_vector, doc['field'])
After:  cosineSimilarity(params.query_vector, 'field')
```

This seems more intuitive, since we don't allow direct access to vector doc
values and the meaning of `doc['field']` is unclear.

The PR makes the following changes (broken into distinct commits):
* Add new function signatures of the form `function(params.query_vector,
'field')` and deprecate the old ones. Because Painless doesn't allow two
methods with the same name and number of arguments, we allow a generic `Object`
to be passed in to the function and decide on the behavior through an
`instanceof` check.
* Refactor the class bindings so that the document field is passed to the
constructor instead of the instance method. This allows us to avoid retrieving
the vector doc values on every function invocation, which gives a tiny speed-up
in benchmarks.
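
The two changes above can be sketched as follows. This is a minimal illustration, not the code in this PR: `FakeDocValues`, `DotProductSketch`, and the `usedDeprecatedForm` flag are hypothetical stand-ins for the real doc values classes, the function bindings, and the deprecation logger.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class FieldArgDispatchSketch {

    // Hypothetical stand-in for the vector doc values of one field.
    static final class FakeDocValues {
        final float[] vector;

        FakeDocValues(float[] vector) {
            this.vector = vector;
        }
    }

    // The field is resolved once in the constructor, so the scoring method takes no arguments.
    static final class DotProductSketch {
        final float[] queryVector;
        final FakeDocValues docValues;
        final boolean usedDeprecatedForm; // stand-in for emitting a deprecation warning

        DotProductSketch(Map<String, FakeDocValues> doc, List<Number> query, Object field) {
            queryVector = new float[query.size()];
            for (int i = 0; i < queryVector.length; i++) {
                queryVector[i] = query.get(i).floatValue();
            }
            if (field instanceof String) {
                // New form: function(query, 'field')
                docValues = doc.get((String) field);
                usedDeprecatedForm = false;
            } else if (field instanceof FakeDocValues) {
                // Old form: function(query, doc['field'])
                docValues = (FakeDocValues) field;
                usedDeprecatedForm = true;
            } else {
                throw new IllegalArgumentException("field must be a String or doc values");
            }
        }

        double dotProduct() {
            double result = 0;
            for (int i = 0; i < queryVector.length; i++) {
                result += queryVector[i] * docValues.vector[i];
            }
            return result;
        }
    }

    public static void main(String[] args) {
        Map<String, FakeDocValues> doc = new HashMap<>();
        doc.put("field", new FakeDocValues(new float[] {4f, 5f, 6f}));
        List<Number> query = List.<Number>of(1, 2, 3);

        DotProductSketch byName = new DotProductSketch(doc, query, "field");
        DotProductSketch byReference = new DotProductSketch(doc, query, doc.get("field"));
        System.out.println(byName.dotProduct() + " deprecated=" + byName.usedDeprecatedForm);
        System.out.println(byReference.dotProduct() + " deprecated=" + byReference.usedDeprecatedForm);
    }
}
```

Both call forms go through the same constructor and produce the same score; only the old form takes the deprecated branch.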

Note that this PR adds new signatures for the sparse vector functions too, even
though sparse vectors are deprecated. Keeping everything symmetric between dense
and sparse vectors seemed simplest to understand (for both us and users).
Julie Tibshirani 2019-10-29 15:46:05 -07:00 committed by GitHub
parent 25724c5c46
commit 89c65752dc
13 changed files with 377 additions and 144 deletions

View File

@ -10,8 +10,8 @@ The following specialized API is available in the Score context.
==== Static Methods ==== Static Methods
The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values. The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values.
* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues) * double cosineSimilarity(List *, String)
* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) * double cosineSimilaritySparse(Map *, String)
* double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
* double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
* double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime) * double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier.
* double decayNumericExp(double *, double *, double *, double *, double) * double decayNumericExp(double *, double *, double *, double *, double)
* double decayNumericGauss(double *, double *, double *, double *, double) * double decayNumericGauss(double *, double *, double *, double *, double)
* double decayNumericLinear(double *, double *, double *, double *, double) * double decayNumericLinear(double *, double *, double *, double *, double)
* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues) * double dotProduct(List, String)
* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues) * double dotProductSparse(Map *, String)
* double randomScore(int *) * double randomScore(int *)
* double randomScore(int *, String *) * double randomScore(int *, String *)
* double saturation(double, double) * double saturation(double, double)

View File

@ -29,3 +29,10 @@ We have not seen much interest in this experimental field type, and don't see
a clear use case as it's currently designed. If you have feedback or a clear use case as it's currently designed. If you have feedback or
suggestions around sparse vector functionality, please let us know through suggestions around sparse vector functionality, please let us know through
GitHub or the 'discuss' forums. GitHub or the 'discuss' forums.
[discrete]
==== Update to vector function signatures
The vector functions of the form `function(query, doc['field'])` are
deprecated, and the form `function(query, 'field')` should be used instead.
For example, `cosineSimilarity(query, doc['field'])` is replaced by
`cosineSimilarity(query, 'field')`.

View File

@ -68,7 +68,7 @@ GET my_index/_search
} }
}, },
"script": { "script": {
"source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2> "source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2>
"params": { "params": {
"query_vector": [4, 3.4, -0.2] <3> "query_vector": [4, 3.4, -0.2] <3>
} }
@ -105,7 +105,7 @@ GET my_index/_search
}, },
"script": { "script": {
"source": """ "source": """
double value = dotProduct(params.query_vector, doc['my_dense_vector']); double value = dotProduct(params.query_vector, 'my_dense_vector');
return sigmoid(1, Math.E, -value); <1> return sigmoid(1, Math.E, -value); <1>
""", """,
"params": { "params": {
@ -139,7 +139,7 @@ GET my_index/_search
} }
}, },
"script": { "script": {
"source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1> "source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1>
"params": { "params": {
"queryVector": [4, 3.4, -0.2] "queryVector": [4, 3.4, -0.2]
} }
@ -178,7 +178,7 @@ GET my_index/_search
} }
}, },
"script": { "script": {
"source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))", "source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))",
"params": { "params": {
"queryVector": [4, 3.4, -0.2] "queryVector": [4, 3.4, -0.2]
} }
@ -196,7 +196,7 @@ You can check if a document has a value for the field `my_vector` by
[source,js] [source,js]
-------------------------------------------------- --------------------------------------------------
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, doc['my_vector'])" "source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')"
-------------------------------------------------- --------------------------------------------------
// NOTCONSOLE // NOTCONSOLE
@ -262,7 +262,7 @@ GET my_sparse_index/_search
} }
}, },
"script": { "script": {
"source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0", "source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0",
"params": { "params": {
"query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} "query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
} }
@ -294,7 +294,7 @@ GET my_sparse_index/_search
}, },
"script": { "script": {
"source": """ "source": """
double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']); double value = dotProductSparse(params.query_vector, 'my_sparse_vector');
return sigmoid(1, Math.E, -value); return sigmoid(1, Math.E, -value);
""", """,
"params": { "params": {
@ -327,7 +327,7 @@ GET my_sparse_index/_search
} }
}, },
"script": { "script": {
"source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))", "source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))",
"params": { "params": {
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
} }
@ -358,7 +358,7 @@ GET my_sparse_index/_search
} }
}, },
"script": { "script": {
"source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))", "source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))",
"params": { "params": {
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0} "queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
} }

View File

@ -119,7 +119,7 @@ public abstract class ScoreScript {
} }
/** The doc lookup for the Lucene segment this script was created for. */ /** The doc lookup for the Lucene segment this script was created for. */
public final Map<String, ScriptDocValues<?>> getDoc() { public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc(); return leafLookup.doc();
} }

View File

@ -1,6 +1,6 @@
setup: setup:
- skip: - skip:
features: headers features: [headers, warnings]
version: " - 7.2.99" version: " - 7.2.99"
reason: "dense_vector dims parameter was added from 7.3" reason: "dense_vector dims parameter was added from 7.3"
@ -52,7 +52,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProduct(params.query_vector, doc['my_dense_vector'])" source: "dotProduct(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -82,7 +82,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -99,3 +99,26 @@ setup:
- match: {hits.hits.2._id: "1"} - match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78} - gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791} - lte: {hits.hits.2._score: 0.791}
---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}

View File

@ -53,7 +53,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l1norm(params.query_vector, doc['my_dense_vector'])" source: "l1norm(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
@ -83,7 +83,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l2norm(params.query_vector, doc['my_dense_vector'])" source: "l2norm(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0] query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]

View File

@ -62,7 +62,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [10, 10, 10] query_vector: [10, 10, 10]
@ -81,7 +81,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [10.0, 10.0, 10.0] query_vector: [10.0, 10.0, 10.0]
@ -111,7 +111,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [1, 2, 3, 4] query_vector: [1, 2, 3, 4]
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }
@ -125,7 +125,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProduct(params.query_vector, doc['my_dense_vector'])" source: "dotProduct(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [1, 2, 3, 4] query_vector: [1, 2, 3, 4]
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }
@ -161,7 +161,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [10.0, 10.0, 10.0] query_vector: [10.0, 10.0, 10.0]
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }
@ -177,7 +177,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])" source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: [10.0, 10.0, 10.0] query_vector: [10.0, 10.0, 10.0]
@ -208,7 +208,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])" source: "dotProductSparse(params.query_vector, 'my_dense_vector')"
params: params:
query_vector: {"2": 0.5, "10" : 111.3, "3": 44} query_vector: {"2": 0.5, "10" : 111.3, "3": 44}
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }

View File

@ -55,7 +55,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -87,7 +87,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -104,3 +104,27 @@ setup:
- match: {hits.hits.2._id: "1"} - match: {hits.hits.2._id: "1"}
- gte: {hits.hits.2._score: 0.78} - gte: {hits.hits.2._score: 0.78}
- lte: {hits.hits.2._score: 0.791} - lte: {hits.hits.2._score: 0.791}
---
"Deprecated function signature":
- do:
headers:
Content-Type: application/json
warnings:
- The [sparse_vector] field type is deprecated and will be removed in 8.0.
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
search:
rest_total_hits_as_int: true
body:
query:
script_score:
query: {match_all: {} }
script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
params:
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
- match: {hits.total: 3}
- match: {hits.hits.0._id: "3"}
- match: {hits.hits.1._id: "2"}
- match: {hits.hits.2._id: "1"}

View File

@ -55,7 +55,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
@ -88,7 +88,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0} query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}

View File

@ -61,7 +61,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10} query_vector: {"1": 10}
@ -83,7 +83,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10.0} query_vector: {"1": 10.0}
@ -127,7 +127,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10.0} query_vector: {"1": 10.0}
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }
@ -145,7 +145,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10.0} query_vector: {"1": 10.0}
@ -194,7 +194,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555} query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555}
@ -229,7 +229,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProduct(params.query_vector, doc['my_sparse_vector'])" source: "dotProduct(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: [0.5, 111] query_vector: [0.5, 111]
- match: { error.root_cause.0.type: "script_exception" } - match: { error.root_cause.0.type: "script_exception" }
@ -272,7 +272,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])" source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10, "5": 5} query_vector: {"1": 10, "5": 5}
@ -303,7 +303,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])" source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10, "5" : 5} query_vector: {"1": 10, "5" : 5}
@ -333,7 +333,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])" source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10, "5": 5} query_vector: {"1": 10, "5": 5}
@ -360,7 +360,7 @@ setup:
script_score: script_score:
query: {match_all: {} } query: {match_all: {} }
script: script:
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])" source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
params: params:
query_vector: {"1": 10, "5": 5} query_vector: {"1": 10, "5": 5}

View File

@ -9,12 +9,16 @@ package org.elasticsearch.xpack.vectors.query;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version; import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.script.ScoreScript; import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper; import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder; import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues;
import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -22,6 +26,10 @@ import java.util.Map;
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues; import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
public class ScoreScriptUtils { public class ScoreScriptUtils {
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class));
static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " +
"the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " +
"cosineSimilarity(query, 'field').";
//**************FUNCTIONS FOR DENSE VECTORS //**************FUNCTIONS FOR DENSE VECTORS
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings. // Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
@ -31,9 +39,12 @@ public class ScoreScriptUtils {
public static class DenseVectorFunction { public static class DenseVectorFunction {
final ScoreScript scoreScript; final ScoreScript scoreScript;
final float[] queryVector; final float[] queryVector;
final VectorScriptDocValues.DenseVectorScriptDocValues docValues;
public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) { public DenseVectorFunction(ScoreScript scoreScript,
this(scoreScript, queryVector, false); List<Number> queryVector,
Object field) {
this(scoreScript, queryVector, field, false);
} }
/** /**
@ -45,6 +56,7 @@ public class ScoreScriptUtils {
*/ */
public DenseVectorFunction(ScoreScript scoreScript, public DenseVectorFunction(ScoreScript scoreScript,
List<Number> queryVector, List<Number> queryVector,
Object field,
boolean normalizeQuery) { boolean normalizeQuery) {
this.scoreScript = scoreScript; this.scoreScript = scoreScript;
@ -62,9 +74,28 @@ public class ScoreScriptUtils {
this.queryVector[dim] /= queryMagnitude; this.queryVector[dim] /= queryMagnitude;
} }
} }
if (field instanceof String) {
String fieldName = (String) field;
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof DenseVectorScriptDocValues) {
docValues = (DenseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}
} }
public void validateDocVector(BytesRef vector) { BytesRef getEncodedVector() {
try {
docValues.setNextDocId(scoreScript._getDocId());
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
// Validate the encoded vector's length.
BytesRef vector = docValues.getEncodedValue();
if (vector == null) { if (vector == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
} }
@ -74,20 +105,21 @@ public class ScoreScriptUtils {
throw new IllegalArgumentException("The query vector has a different number of dimensions [" + throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
queryVector.length + "] than the document vectors [" + vectorLength + "]."); queryVector.length + "] than the document vectors [" + vectorLength + "].");
} }
return vector;
} }
} }
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors // Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
public static final class L1Norm extends DenseVectorFunction { public static final class L1Norm extends DenseVectorFunction {
public L1Norm(ScoreScript scoreScript, List<Number> queryVector) { public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector); super(scoreScript, queryVector, field);
} }
public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { public double l1norm() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double l1norm = 0; double l1norm = 0;
for (float queryValue : queryVector) { for (float queryValue : queryVector) {
@ -100,13 +132,12 @@ public class ScoreScriptUtils {
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors // Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
public static final class L2Norm extends DenseVectorFunction { public static final class L2Norm extends DenseVectorFunction {
public L2Norm(ScoreScript scoreScript, List<Number> queryVector) { public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector); super(scoreScript, queryVector, field);
} }
public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { public double l2norm() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double l2norm = 0; double l2norm = 0;
@ -121,13 +152,12 @@ public class ScoreScriptUtils {
// Calculate a dot product between a query's dense vector and documents' dense vectors // Calculate a dot product between a query's dense vector and documents' dense vectors
public static final class DotProduct extends DenseVectorFunction { public static final class DotProduct extends DenseVectorFunction {
public DotProduct(ScoreScript scoreScript, List<Number> queryVector) { public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector); super(scoreScript, queryVector, field);
} }
public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){ public double dotProduct() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double dotProduct = 0; double dotProduct = 0;
@ -141,14 +171,12 @@ public class ScoreScriptUtils {
// Calculate cosine similarity between a query's dense vector and documents' dense vectors // Calculate cosine similarity between a query's dense vector and documents' dense vectors
public static final class CosineSimilarity extends DenseVectorFunction { public static final class CosineSimilarity extends DenseVectorFunction {
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) { public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
super(scoreScript, queryVector, true); super(scoreScript, queryVector, field, true);
} }
public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) { public double cosineSimilarity() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length); ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
double dotProduct = 0.0; double dotProduct = 0.0;
@ -176,15 +204,17 @@ public class ScoreScriptUtils {
// per script execution for all documents. // per script execution for all documents.
public static class SparseVectorFunction { public static class SparseVectorFunction {
static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class));
final ScoreScript scoreScript; final ScoreScript scoreScript;
final float[] queryValues; final float[] queryValues;
final int[] queryDims; final int[] queryDims;
final VectorScriptDocValues.SparseVectorScriptDocValues docValues;
// prepare queryVector once per script execution // prepare queryVector once per script execution
// queryVector represents a map of dimensions to values // queryVector represents a map of dimensions to values
public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector) { public SparseVectorFunction(ScoreScript scoreScript,
Map<String, Number> queryVector,
Object field) {
this.scoreScript = scoreScript; this.scoreScript = scoreScript;
//break vector into two arrays dims and values //break vector into two arrays dims and values
int n = queryVector.size(); int n = queryVector.size();
@ -203,28 +233,46 @@ public class ScoreScriptUtils {
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
sortSparseDimsFloatValues(queryDims, queryValues, n); sortSparseDimsFloatValues(queryDims, queryValues, n);
if (field instanceof String) {
String fieldName = (String) field;
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
} else if (field instanceof SparseVectorScriptDocValues) {
docValues = (SparseVectorScriptDocValues) field;
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
} else {
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
"VectorScriptDocValues");
}
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE); deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
} }
public void validateDocVector(BytesRef vector) { BytesRef getEncodedVector() {
try {
docValues.setNextDocId(scoreScript._getDocId());
} catch (IOException e) {
throw ExceptionsHelper.convertToElastic(e);
}
BytesRef vector = docValues.getEncodedValue();
if (vector == null) { if (vector == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!"); throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
} }
return vector;
} }
} }
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors // Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
public static final class L1NormSparse extends SparseVectorFunction { public static final class L1NormSparse extends SparseVectorFunction {
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector) { public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector); super(scoreScript, queryVector, docVector);
} }
public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { public double l1normSparse() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
int queryIndex = 0; int queryIndex = 0;
int docIndex = 0; int docIndex = 0;
double l1norm = 0; double l1norm = 0;
@ -255,16 +303,15 @@ public class ScoreScriptUtils {
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors // Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
public static final class L2NormSparse extends SparseVectorFunction { public static final class L2NormSparse extends SparseVectorFunction {
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) { public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector); super(scoreScript, queryVector, docVector);
} }
public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { public double l2normSparse() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
int queryIndex = 0; int queryIndex = 0;
int docIndex = 0; int docIndex = 0;
double l2norm = 0; double l2norm = 0;
@ -298,16 +345,15 @@ public class ScoreScriptUtils {
// Calculate a dot product between a query's sparse vector and documents' sparse vectors // Calculate a dot product between a query's sparse vector and documents' sparse vectors
public static final class DotProductSparse extends SparseVectorFunction { public static final class DotProductSparse extends SparseVectorFunction {
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector) { public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector); super(scoreScript, queryVector, docVector);
} }
public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { public double dotProductSparse() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
return intDotProductSparse(queryValues, queryDims, docValues, docDims); return intDotProductSparse(queryValues, queryDims, docValues, docDims);
} }
} }
@ -316,8 +362,8 @@ public class ScoreScriptUtils {
public static final class CosineSimilaritySparse extends SparseVectorFunction { public static final class CosineSimilaritySparse extends SparseVectorFunction {
final double queryVectorMagnitude; final double queryVectorMagnitude;
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector) { public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
super(scoreScript, queryVector); super(scoreScript, queryVector, docVector);
double dotProduct = 0; double dotProduct = 0;
for (int i = 0; i< queryDims.length; i++) { for (int i = 0; i< queryDims.length; i++) {
dotProduct += queryValues[i] * queryValues[i]; dotProduct += queryValues[i] * queryValues[i];
@ -325,10 +371,8 @@ public class ScoreScriptUtils {
this.queryVectorMagnitude = Math.sqrt(dotProduct); this.queryVectorMagnitude = Math.sqrt(dotProduct);
} }
public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) { public double cosineSimilaritySparse() {
BytesRef vector = dvs.getEncodedValue(); BytesRef vector = getEncodedVector();
validateDocVector(vector);
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector); int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector); float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
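
The constructor changes above accept the field as a generic `Object` and pick the behavior with an `instanceof` check, so the doc values are resolved once per script rather than on every invocation. A minimal, self-contained sketch of that pattern follows. `FakeDocValues`, `DotProductSketch`, the plain `Map` standing in for the script's doc lookup, and the `usedDeprecatedSignature` flag are all hypothetical names for illustration, not Elasticsearch classes. The null check on a missing field is also an assumption of the sketch, not something the diff shows.

```java
import java.util.Map;

// Hypothetical stand-in for a vector doc-values object; not the Elasticsearch class.
final class FakeDocValues {
    final float[] vector;

    FakeDocValues(float[] vector) {
        this.vector = vector;
    }
}

// Sketch of a function binding that resolves its 'field' argument once, in the constructor.
final class DotProductSketch {
    private final float[] queryVector;
    private final FakeDocValues docValues;
    // Stands in for the deprecation log call made on the old signature.
    final boolean usedDeprecatedSignature;

    DotProductSketch(Map<String, FakeDocValues> doc, float[] queryVector, Object field) {
        this.queryVector = queryVector;
        if (field instanceof String) {
            // New signature: the caller passes the field name.
            FakeDocValues found = doc.get((String) field);
            if (found == null) {
                // Assumption of this sketch: fail fast on an unknown field name.
                throw new IllegalArgumentException("No vector field named [" + field + "]");
            }
            this.docValues = found;
            this.usedDeprecatedSignature = false;
        } else if (field instanceof FakeDocValues) {
            // Old signature: the caller passes the doc values directly.
            this.docValues = (FakeDocValues) field;
            this.usedDeprecatedSignature = true;
        } else {
            throw new IllegalArgumentException(
                "The 'field' argument must be a String or a doc-values object");
        }
    }

    // Takes no arguments: everything it needs was captured by the constructor.
    double dotProduct() {
        float[] docVector = docValues.vector;
        if (docVector.length != queryVector.length) {
            throw new IllegalArgumentException("query vector has a different number of dimensions ["
                + queryVector.length + "] than the document vectors [" + docVector.length + "]");
        }
        double sum = 0;
        for (int i = 0; i < queryVector.length; i++) {
            sum += (double) queryVector[i] * docVector[i];
        }
        return sum;
    }
}
```

Both call styles produce the same result; only the second one sets the deprecation flag, which mirrors how the tests in this change assert an extra warning for the old signature.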


@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import {
 }

 static_import {
-    double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
-    double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
-    double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
-    double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
-    double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
-    double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
-    double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
-    double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
+    double l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
+    double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
+    double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
+    double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
+    double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
+    double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
+    double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
+    double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
 }


@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse; import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -39,6 +40,66 @@ public class ScoreScriptUtilsTests extends ESTestCase {
} }
private void testDenseVectorFunctions(Version indexVersion) { private void testDenseVectorFunctions(Version indexVersion) {
String field = "vector";
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field);
double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field);
double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector, field);
double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector, field);
double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
// test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
// test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field);
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
}
public void testDeprecatedDenseVectorFunctions() {
testDeprecatedDenseVectorFunctions(Version.V_7_4_0);
testDeprecatedDenseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedDenseVectorFunctions(Version indexVersion) {
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion); BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class); VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
@ -50,45 +111,53 @@ public class ScoreScriptUtilsTests extends ESTestCase {
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f); List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
// test dotProduct // test dotProduct
DotProduct dotProduct = new DotProduct(scoreScript, queryVector); DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs);
double result = dotProduct.dotProduct(dvs); double result = dotProduct.dotProduct();
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001); assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity // test cosineSimilarity
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector); CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs);
double result2 = cosineSimilarity.cosineSimilarity(dvs); double result2 = cosineSimilarity.cosineSimilarity();
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001); assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1Norm // test l1Norm
L1Norm l1norm = new L1Norm(scoreScript, queryVector); L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs);
double result3 = l1norm.l1norm(dvs); double result3 = l1norm.l1norm();
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001); assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm // test l2norm
L2Norm l2norm = new L2Norm(scoreScript, queryVector); L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs);
double result4 = l2norm.l2norm(dvs); double result4 = l2norm.l2norm();
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001); assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test dotProduct fails when queryVector has wrong number of dims // test dotProduct fails when queryVector has wrong number of dims
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3); List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector); DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs);
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs)); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity fails when queryVector has wrong number of dims // test cosineSimilarity fails when queryVector has wrong number of dims
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector); CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs)); e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm fails when queryVector has wrong number of dims // test l1norm fails when queryVector has wrong number of dims
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector); L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs)); e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm fails when queryVector has wrong number of dims // test l2norm fails when queryVector has wrong number of dims
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector); L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs);
e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs)); e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]")); assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
} }
public void testSparseVectorFunctions() { public void testSparseVectorFunctions() {
@ -97,12 +166,62 @@ public class ScoreScriptUtilsTests extends ESTestCase {
} }
private void testSparseVectorFunctions(Version indexVersion) { private void testSparseVectorFunctions(Version indexVersion) {
String field = "vector";
int[] docVectorDims = {2, 10, 50, 113, 4545}; int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector( BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length); indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector); when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5);
put("10", 111.3);
put("50", -13.0);
put("113", 14.8);
put("4545", -156.0);
}};
// test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
// test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
// test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
// test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
}
public void testDeprecatedSparseVectorFunctions() {
testDeprecatedSparseVectorFunctions(Version.V_7_4_0);
testDeprecatedSparseVectorFunctions(Version.CURRENT);
}
private void testDeprecatedSparseVectorFunctions(Version indexVersion) {
int[] docVectorDims = {2, 10, 50, 113, 4545};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class); ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(indexVersion); when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
@ -115,29 +234,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}}; }};
// test dotProduct // test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs);
double result = docProductSparse.dotProductSparse(dvs); double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test cosineSimilarity // test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l1norm // test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs);
double result3 = l1Norm.l1normSparse(dvs); double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001); assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
// test l2norm // test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs);
double result4 = l2Norm.l2normSparse(dvs); double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001); assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
} }
public void testSparseVectorMissingDimensions1() { public void testSparseVectorMissingDimensions1() {
String field = "vector";
// Document vector's biggest dimension > query vector's biggest dimension // Document vector's biggest dimension > query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
@ -145,8 +268,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector); when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class); ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{ Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5); put("2", 0.5);
put("10", 111.3); put("10", 111.3);
@ -157,29 +283,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}}; }};
// test dotProduct // test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse(dvs); double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity // test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm // test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse(dvs); double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm // test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse(dvs); double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
} }
public void testSparseVectorMissingDimensions2() { public void testSparseVectorMissingDimensions2() {
String field = "vector";
// Document vector's biggest dimension < query vector's biggest dimension // Document vector's biggest dimension < query vector's biggest dimension
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546}; int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f}; float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
@ -187,8 +317,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length); Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class); VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
when(dvs.getEncodedValue()).thenReturn(encodedDocVector); when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
ScoreScript scoreScript = mock(ScoreScript.class); ScoreScript scoreScript = mock(ScoreScript.class);
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT); when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
Map<String, Number> queryVector = new HashMap<String, Number>() {{ Map<String, Number> queryVector = new HashMap<String, Number>() {{
put("2", 0.5); put("2", 0.5);
put("10", 111.3); put("10", 111.3);
@ -199,25 +332,27 @@ public class ScoreScriptUtilsTests extends ESTestCase {
}}; }};
// test dotProduct // test dotProduct
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector); DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
double result = docProductSparse.dotProductSparse(dvs); double result = docProductSparse.dotProductSparse();
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001); assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test cosineSimilarity // test cosineSimilarity
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector); CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs); double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001); assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l1norm // test l1norm
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector); L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
double result3 = l1Norm.l1normSparse(dvs); double result3 = l1Norm.l1normSparse();
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001); assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
// test l2norm // test l2norm
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector); L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
double result4 = l2Norm.l2normSparse(dvs); double result4 = l2Norm.l2normSparse();
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001); assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE); assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
} }
} }