Update the signature of vector script functions. (#48653)
Previously the functions accepted a doc values reference, whereas they now accept the name of the vector field. Here's an example of how a vector function was called before and after the change. ``` Before: cosineSimilarity(params.query_vector, doc['field']) After: cosineSimilarity(params.query_vector, 'field') ``` This seems more intuitive, since we don't allow direct access to vector doc values and the meaning of `doc['field']` is unclear. The PR makes the following changes (broken into distinct commits): * Add new function signatures of the form `function(params.query_vector, 'field')` and deprecate the old ones. Because Painless doesn't allow two methods with the same name and number of arguments, we allow a generic `Object` to be passed in to the function and decide on the behavior through an `instanceof` check. * Refactor the class bindings so that the document field is passed to the constructor instead of the instance method. This allows us to avoid retrieving the vector doc values on every function invocation, which gives a tiny speed-up in benchmarks. Note that this PR adds new signatures for the sparse vector functions too, even though sparse vectors are deprecated. It seemed simplest to understand (for both us and users) to keep everything symmetric between dense and sparse vectors.
This commit is contained in:
parent
25724c5c46
commit
89c65752dc
|
@ -10,8 +10,8 @@ The following specialized API is available in the Score context.
|
||||||
==== Static Methods
|
==== Static Methods
|
||||||
The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values.
|
The following methods are directly callable without a class/instance qualifier. Note parameters denoted by a (*) are treated as read-only values.
|
||||||
|
|
||||||
* double cosineSimilarity(List *, VectorScriptDocValues.DenseVectorScriptDocValues)
|
* double cosineSimilarity(List *, String)
|
||||||
* double cosineSimilaritySparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues)
|
* double cosineSimilaritySparse(Map *, String)
|
||||||
* double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
* double decayDateExp(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
||||||
* double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
* double decayDateGauss(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
||||||
* double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
* double decayDateLinear(String *, String *, String *, double *, JodaCompatibleZonedDateTime)
|
||||||
|
@ -21,8 +21,8 @@ The following methods are directly callable without a class/instance qualifier.
|
||||||
* double decayNumericExp(double *, double *, double *, double *, double)
|
* double decayNumericExp(double *, double *, double *, double *, double)
|
||||||
* double decayNumericGauss(double *, double *, double *, double *, double)
|
* double decayNumericGauss(double *, double *, double *, double *, double)
|
||||||
* double decayNumericLinear(double *, double *, double *, double *, double)
|
* double decayNumericLinear(double *, double *, double *, double *, double)
|
||||||
* double dotProduct(List, VectorScriptDocValues.DenseVectorScriptDocValues)
|
* double dotProduct(List, String)
|
||||||
* double dotProductSparse(Map *, VectorScriptDocValues.SparseVectorScriptDocValues)
|
* double dotProductSparse(Map *, String)
|
||||||
* double randomScore(int *)
|
* double randomScore(int *)
|
||||||
* double randomScore(int *, String *)
|
* double randomScore(int *, String *)
|
||||||
* double saturation(double, double)
|
* double saturation(double, double)
|
||||||
|
|
|
@ -29,3 +29,10 @@ We have not seen much interest in this experimental field type, and don't see
|
||||||
a clear use case as it's currently designed. If you have feedback or
|
a clear use case as it's currently designed. If you have feedback or
|
||||||
suggestions around sparse vector functionality, please let us know through
|
suggestions around sparse vector functionality, please let us know through
|
||||||
GitHub or the 'discuss' forums.
|
GitHub or the 'discuss' forums.
|
||||||
|
|
||||||
|
[discrete]
|
||||||
|
==== Update to vector function signatures
|
||||||
|
The vector functions of the form `function(query, doc['field'])` are
|
||||||
|
deprecated, and the form `function(query, 'field')` should be used instead.
|
||||||
|
For example, `cosineSimilarity(query, doc['field'])` is replaced by
|
||||||
|
`cosineSimilarity(query, 'field')`.
|
|
@ -68,7 +68,7 @@ GET my_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "cosineSimilarity(params.query_vector, doc['my_dense_vector']) + 1.0", <2>
|
"source": "cosineSimilarity(params.query_vector, 'my_dense_vector') + 1.0", <2>
|
||||||
"params": {
|
"params": {
|
||||||
"query_vector": [4, 3.4, -0.2] <3>
|
"query_vector": [4, 3.4, -0.2] <3>
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ GET my_index/_search
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": """
|
"source": """
|
||||||
double value = dotProduct(params.query_vector, doc['my_dense_vector']);
|
double value = dotProduct(params.query_vector, 'my_dense_vector');
|
||||||
return sigmoid(1, Math.E, -value); <1>
|
return sigmoid(1, Math.E, -value); <1>
|
||||||
""",
|
""",
|
||||||
"params": {
|
"params": {
|
||||||
|
@ -139,7 +139,7 @@ GET my_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "1 / (1 + l1norm(params.queryVector, doc['my_dense_vector']))", <1>
|
"source": "1 / (1 + l1norm(params.queryVector, 'my_dense_vector'))", <1>
|
||||||
"params": {
|
"params": {
|
||||||
"queryVector": [4, 3.4, -0.2]
|
"queryVector": [4, 3.4, -0.2]
|
||||||
}
|
}
|
||||||
|
@ -178,7 +178,7 @@ GET my_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "1 / (1 + l2norm(params.queryVector, doc['my_dense_vector']))",
|
"source": "1 / (1 + l2norm(params.queryVector, 'my_dense_vector'))",
|
||||||
"params": {
|
"params": {
|
||||||
"queryVector": [4, 3.4, -0.2]
|
"queryVector": [4, 3.4, -0.2]
|
||||||
}
|
}
|
||||||
|
@ -196,7 +196,7 @@ You can check if a document has a value for the field `my_vector` by
|
||||||
|
|
||||||
[source,js]
|
[source,js]
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, doc['my_vector'])"
|
"source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')"
|
||||||
--------------------------------------------------
|
--------------------------------------------------
|
||||||
// NOTCONSOLE
|
// NOTCONSOLE
|
||||||
|
|
||||||
|
@ -262,7 +262,7 @@ GET my_sparse_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector']) + 1.0",
|
"source": "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector') + 1.0",
|
||||||
"params": {
|
"params": {
|
||||||
"query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
"query_vector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
||||||
}
|
}
|
||||||
|
@ -294,7 +294,7 @@ GET my_sparse_index/_search
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": """
|
"source": """
|
||||||
double value = dotProductSparse(params.query_vector, doc['my_sparse_vector']);
|
double value = dotProductSparse(params.query_vector, 'my_sparse_vector');
|
||||||
return sigmoid(1, Math.E, -value);
|
return sigmoid(1, Math.E, -value);
|
||||||
""",
|
""",
|
||||||
"params": {
|
"params": {
|
||||||
|
@ -327,7 +327,7 @@ GET my_sparse_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "1 / (1 + l1normSparse(params.queryVector, doc['my_sparse_vector']))",
|
"source": "1 / (1 + l1normSparse(params.queryVector, 'my_sparse_vector'))",
|
||||||
"params": {
|
"params": {
|
||||||
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
||||||
}
|
}
|
||||||
|
@ -358,7 +358,7 @@ GET my_sparse_index/_search
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": "1 / (1 + l2normSparse(params.queryVector, doc['my_sparse_vector']))",
|
"source": "1 / (1 + l2normSparse(params.queryVector, 'my_sparse_vector'))",
|
||||||
"params": {
|
"params": {
|
||||||
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
"queryVector": {"2": 0.5, "10" : 111.3, "50": -1.3, "113": 14.8, "4545": 156.0}
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,7 +119,7 @@ public abstract class ScoreScript {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The doc lookup for the Lucene segment this script was created for. */
|
/** The doc lookup for the Lucene segment this script was created for. */
|
||||||
public final Map<String, ScriptDocValues<?>> getDoc() {
|
public Map<String, ScriptDocValues<?>> getDoc() {
|
||||||
return leafLookup.doc();
|
return leafLookup.doc();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
setup:
|
setup:
|
||||||
- skip:
|
- skip:
|
||||||
features: headers
|
features: [headers, warnings]
|
||||||
version: " - 7.2.99"
|
version: " - 7.2.99"
|
||||||
reason: "dense_vector dims parameter was added from 7.3"
|
reason: "dense_vector dims parameter was added from 7.3"
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProduct(params.query_vector, doc['my_dense_vector'])"
|
source: "dotProduct(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
@ -99,3 +99,26 @@ setup:
|
||||||
- match: {hits.hits.2._id: "1"}
|
- match: {hits.hits.2._id: "1"}
|
||||||
- gte: {hits.hits.2._score: 0.78}
|
- gte: {hits.hits.2._score: 0.78}
|
||||||
- lte: {hits.hits.2._score: 0.791}
|
- lte: {hits.hits.2._score: 0.791}
|
||||||
|
|
||||||
|
---
|
||||||
|
"Deprecated function signature":
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
warnings:
|
||||||
|
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
||||||
|
params:
|
||||||
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
- match: {hits.hits.0._id: "3"}
|
||||||
|
- match: {hits.hits.1._id: "2"}
|
||||||
|
- match: {hits.hits.2._id: "1"}
|
||||||
|
|
|
@ -53,7 +53,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l1norm(params.query_vector, doc['my_dense_vector'])"
|
source: "l1norm(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l2norm(params.query_vector, doc['my_dense_vector'])"
|
source: "l2norm(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [10, 10, 10]
|
query_vector: [10, 10, 10]
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [10.0, 10.0, 10.0]
|
query_vector: [10.0, 10.0, 10.0]
|
||||||
|
|
||||||
|
@ -111,7 +111,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [1, 2, 3, 4]
|
query_vector: [1, 2, 3, 4]
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
@ -125,7 +125,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProduct(params.query_vector, doc['my_dense_vector'])"
|
source: "dotProduct(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [1, 2, 3, 4]
|
query_vector: [1, 2, 3, 4]
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
@ -161,7 +161,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [10.0, 10.0, 10.0]
|
query_vector: [10.0, 10.0, 10.0]
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
@ -177,7 +177,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, doc['my_dense_vector'])"
|
source: "doc['my_dense_vector'].size() == 0 ? 0 : cosineSimilarity(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [10.0, 10.0, 10.0]
|
query_vector: [10.0, 10.0, 10.0]
|
||||||
|
|
||||||
|
@ -208,7 +208,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProductSparse(params.query_vector, doc['my_dense_vector'])"
|
source: "dotProductSparse(params.query_vector, 'my_dense_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"2": 0.5, "10" : 111.3, "3": 44}
|
query_vector: {"2": 0.5, "10" : 111.3, "3": 44}
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
|
|
@ -55,7 +55,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
||||||
|
|
||||||
|
@ -87,7 +87,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
||||||
|
|
||||||
|
@ -104,3 +104,27 @@ setup:
|
||||||
- match: {hits.hits.2._id: "1"}
|
- match: {hits.hits.2._id: "1"}
|
||||||
- gte: {hits.hits.2._score: 0.78}
|
- gte: {hits.hits.2._score: 0.78}
|
||||||
- lte: {hits.hits.2._score: 0.791}
|
- lte: {hits.hits.2._score: 0.791}
|
||||||
|
|
||||||
|
---
|
||||||
|
"Deprecated function signature":
|
||||||
|
- do:
|
||||||
|
headers:
|
||||||
|
Content-Type: application/json
|
||||||
|
warnings:
|
||||||
|
- The [sparse_vector] field type is deprecated and will be removed in 8.0.
|
||||||
|
- The vector functions of the form function(query, doc['field']) are deprecated, and the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by cosineSimilarity(query, 'field').
|
||||||
|
search:
|
||||||
|
rest_total_hits_as_int: true
|
||||||
|
body:
|
||||||
|
query:
|
||||||
|
script_score:
|
||||||
|
query: {match_all: {} }
|
||||||
|
script:
|
||||||
|
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
||||||
|
params:
|
||||||
|
query_vector: {"2": -0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
||||||
|
|
||||||
|
- match: {hits.total: 3}
|
||||||
|
- match: {hits.hits.0._id: "3"}
|
||||||
|
- match: {hits.hits.1._id: "2"}
|
||||||
|
- match: {hits.hits.2._id: "1"}
|
||||||
|
|
|
@ -55,7 +55,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
query_vector: {"2": 0.5, "10" : 111.3, "50": -13.0, "113": 14.8, "4545": -156.0}
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10}
|
query_vector: {"1": 10}
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10.0}
|
query_vector: {"1": 10.0}
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10.0}
|
query_vector: {"1": 10.0}
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
@ -145,7 +145,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "doc['my_sparse_vector'].size() == 0 ? 0 : cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10.0}
|
query_vector: {"1": 10.0}
|
||||||
|
|
||||||
|
@ -194,7 +194,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555}
|
query_vector: {"100": -200.0, "11" : 300.33, "12": -34.8988, "2": 230.0, "30": 15.555}
|
||||||
|
|
||||||
|
@ -229,7 +229,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProduct(params.query_vector, doc['my_sparse_vector'])"
|
source: "dotProduct(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: [0.5, 111]
|
query_vector: [0.5, 111]
|
||||||
- match: { error.root_cause.0.type: "script_exception" }
|
- match: { error.root_cause.0.type: "script_exception" }
|
||||||
|
@ -272,7 +272,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "dotProductSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "dotProductSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10, "5": 5}
|
query_vector: {"1": 10, "5": 5}
|
||||||
|
|
||||||
|
@ -303,7 +303,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "cosineSimilaritySparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "cosineSimilaritySparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10, "5" : 5}
|
query_vector: {"1": 10, "5" : 5}
|
||||||
|
|
||||||
|
@ -333,7 +333,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l1normSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "l1normSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10, "5": 5}
|
query_vector: {"1": 10, "5": 5}
|
||||||
|
|
||||||
|
@ -360,7 +360,7 @@ setup:
|
||||||
script_score:
|
script_score:
|
||||||
query: {match_all: {} }
|
query: {match_all: {} }
|
||||||
script:
|
script:
|
||||||
source: "l2normSparse(params.query_vector, doc['my_sparse_vector'])"
|
source: "l2normSparse(params.query_vector, 'my_sparse_vector')"
|
||||||
params:
|
params:
|
||||||
query_vector: {"1": 10, "5": 5}
|
query_vector: {"1": 10, "5": 5}
|
||||||
|
|
||||||
|
|
|
@ -9,12 +9,16 @@ package org.elasticsearch.xpack.vectors.query;
|
||||||
|
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.elasticsearch.ExceptionsHelper;
|
||||||
import org.elasticsearch.Version;
|
import org.elasticsearch.Version;
|
||||||
import org.elasticsearch.common.logging.DeprecationLogger;
|
import org.elasticsearch.common.logging.DeprecationLogger;
|
||||||
import org.elasticsearch.script.ScoreScript;
|
import org.elasticsearch.script.ScoreScript;
|
||||||
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
|
import org.elasticsearch.xpack.vectors.mapper.SparseVectorFieldMapper;
|
||||||
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
|
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
|
||||||
|
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.DenseVectorScriptDocValues;
|
||||||
|
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues.SparseVectorScriptDocValues;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -22,6 +26,10 @@ import java.util.Map;
|
||||||
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
|
import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
|
||||||
|
|
||||||
public class ScoreScriptUtils {
|
public class ScoreScriptUtils {
|
||||||
|
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class));
|
||||||
|
static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " +
|
||||||
|
"the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " +
|
||||||
|
"cosineSimilarity(query, 'field').";
|
||||||
|
|
||||||
//**************FUNCTIONS FOR DENSE VECTORS
|
//**************FUNCTIONS FOR DENSE VECTORS
|
||||||
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
|
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
|
||||||
|
@ -31,9 +39,12 @@ public class ScoreScriptUtils {
|
||||||
public static class DenseVectorFunction {
|
public static class DenseVectorFunction {
|
||||||
final ScoreScript scoreScript;
|
final ScoreScript scoreScript;
|
||||||
final float[] queryVector;
|
final float[] queryVector;
|
||||||
|
final VectorScriptDocValues.DenseVectorScriptDocValues docValues;
|
||||||
|
|
||||||
public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
|
public DenseVectorFunction(ScoreScript scoreScript,
|
||||||
this(scoreScript, queryVector, false);
|
List<Number> queryVector,
|
||||||
|
Object field) {
|
||||||
|
this(scoreScript, queryVector, field, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -45,6 +56,7 @@ public class ScoreScriptUtils {
|
||||||
*/
|
*/
|
||||||
public DenseVectorFunction(ScoreScript scoreScript,
|
public DenseVectorFunction(ScoreScript scoreScript,
|
||||||
List<Number> queryVector,
|
List<Number> queryVector,
|
||||||
|
Object field,
|
||||||
boolean normalizeQuery) {
|
boolean normalizeQuery) {
|
||||||
this.scoreScript = scoreScript;
|
this.scoreScript = scoreScript;
|
||||||
|
|
||||||
|
@ -62,9 +74,28 @@ public class ScoreScriptUtils {
|
||||||
this.queryVector[dim] /= queryMagnitude;
|
this.queryVector[dim] /= queryMagnitude;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (field instanceof String) {
|
||||||
|
String fieldName = (String) field;
|
||||||
|
docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
|
||||||
|
} else if (field instanceof DenseVectorScriptDocValues) {
|
||||||
|
docValues = (DenseVectorScriptDocValues) field;
|
||||||
|
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
|
||||||
|
"VectorScriptDocValues");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void validateDocVector(BytesRef vector) {
|
BytesRef getEncodedVector() {
|
||||||
|
try {
|
||||||
|
docValues.setNextDocId(scoreScript._getDocId());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw ExceptionsHelper.convertToElastic(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the encoded vector's length.
|
||||||
|
BytesRef vector = docValues.getEncodedValue();
|
||||||
if (vector == null) {
|
if (vector == null) {
|
||||||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
|
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
|
||||||
}
|
}
|
||||||
|
@ -74,20 +105,21 @@ public class ScoreScriptUtils {
|
||||||
throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
|
throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
|
||||||
queryVector.length + "] than the document vectors [" + vectorLength + "].");
|
queryVector.length + "] than the document vectors [" + vectorLength + "].");
|
||||||
}
|
}
|
||||||
|
return vector;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
|
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
|
||||||
public static final class L1Norm extends DenseVectorFunction {
|
public static final class L1Norm extends DenseVectorFunction {
|
||||||
|
|
||||||
public L1Norm(ScoreScript scoreScript, List<Number> queryVector) {
|
public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, field);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
|
public double l1norm() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
||||||
|
|
||||||
double l1norm = 0;
|
double l1norm = 0;
|
||||||
|
|
||||||
for (float queryValue : queryVector) {
|
for (float queryValue : queryVector) {
|
||||||
|
@ -100,13 +132,12 @@ public class ScoreScriptUtils {
|
||||||
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
||||||
public static final class L2Norm extends DenseVectorFunction {
|
public static final class L2Norm extends DenseVectorFunction {
|
||||||
|
|
||||||
public L2Norm(ScoreScript scoreScript, List<Number> queryVector) {
|
public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, field);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
|
public double l2norm() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
||||||
|
|
||||||
double l2norm = 0;
|
double l2norm = 0;
|
||||||
|
@ -121,13 +152,12 @@ public class ScoreScriptUtils {
|
||||||
// Calculate a dot product between a query's dense vector and documents' dense vectors
|
// Calculate a dot product between a query's dense vector and documents' dense vectors
|
||||||
public static final class DotProduct extends DenseVectorFunction {
|
public static final class DotProduct extends DenseVectorFunction {
|
||||||
|
|
||||||
public DotProduct(ScoreScript scoreScript, List<Number> queryVector) {
|
public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, field);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs){
|
public double dotProduct() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
||||||
|
|
||||||
double dotProduct = 0;
|
double dotProduct = 0;
|
||||||
|
@ -141,14 +171,12 @@ public class ScoreScriptUtils {
|
||||||
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
|
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
|
||||||
public static final class CosineSimilarity extends DenseVectorFunction {
|
public static final class CosineSimilarity extends DenseVectorFunction {
|
||||||
|
|
||||||
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) {
|
public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
||||||
super(scoreScript, queryVector, true);
|
super(scoreScript, queryVector, field, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
|
public double cosineSimilarity() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
|
|
||||||
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
|
||||||
|
|
||||||
double dotProduct = 0.0;
|
double dotProduct = 0.0;
|
||||||
|
@ -176,15 +204,17 @@ public class ScoreScriptUtils {
|
||||||
// per script execution for all documents.
|
// per script execution for all documents.
|
||||||
|
|
||||||
public static class SparseVectorFunction {
|
public static class SparseVectorFunction {
|
||||||
static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(SparseVectorFunction.class));
|
|
||||||
|
|
||||||
final ScoreScript scoreScript;
|
final ScoreScript scoreScript;
|
||||||
final float[] queryValues;
|
final float[] queryValues;
|
||||||
final int[] queryDims;
|
final int[] queryDims;
|
||||||
|
|
||||||
|
final VectorScriptDocValues.SparseVectorScriptDocValues docValues;
|
||||||
|
|
||||||
// prepare queryVector once per script execution
|
// prepare queryVector once per script execution
|
||||||
// queryVector represents a map of dimensions to values
|
// queryVector represents a map of dimensions to values
|
||||||
public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector) {
|
public SparseVectorFunction(ScoreScript scoreScript,
|
||||||
|
Map<String, Number> queryVector,
|
||||||
|
Object field) {
|
||||||
this.scoreScript = scoreScript;
|
this.scoreScript = scoreScript;
|
||||||
//break vector into two arrays dims and values
|
//break vector into two arrays dims and values
|
||||||
int n = queryVector.size();
|
int n = queryVector.size();
|
||||||
|
@ -203,28 +233,46 @@ public class ScoreScriptUtils {
|
||||||
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
|
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
|
||||||
sortSparseDimsFloatValues(queryDims, queryValues, n);
|
sortSparseDimsFloatValues(queryDims, queryValues, n);
|
||||||
|
|
||||||
|
if (field instanceof String) {
|
||||||
|
String fieldName = (String) field;
|
||||||
|
docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
|
||||||
|
} else if (field instanceof SparseVectorScriptDocValues) {
|
||||||
|
docValues = (SparseVectorScriptDocValues) field;
|
||||||
|
deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
|
||||||
|
"VectorScriptDocValues");
|
||||||
|
}
|
||||||
|
|
||||||
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void validateDocVector(BytesRef vector) {
|
BytesRef getEncodedVector() {
|
||||||
|
try {
|
||||||
|
docValues.setNextDocId(scoreScript._getDocId());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw ExceptionsHelper.convertToElastic(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
BytesRef vector = docValues.getEncodedValue();
|
||||||
if (vector == null) {
|
if (vector == null) {
|
||||||
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
|
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
|
||||||
}
|
}
|
||||||
|
return vector;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
|
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
|
||||||
public static final class L1NormSparse extends SparseVectorFunction {
|
public static final class L1NormSparse extends SparseVectorFunction {
|
||||||
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector) {
|
public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, docVector);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
|
public double l1normSparse() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
|
|
||||||
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
||||||
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
||||||
|
|
||||||
int queryIndex = 0;
|
int queryIndex = 0;
|
||||||
int docIndex = 0;
|
int docIndex = 0;
|
||||||
double l1norm = 0;
|
double l1norm = 0;
|
||||||
|
@ -255,16 +303,15 @@ public class ScoreScriptUtils {
|
||||||
|
|
||||||
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
|
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
|
||||||
public static final class L2NormSparse extends SparseVectorFunction {
|
public static final class L2NormSparse extends SparseVectorFunction {
|
||||||
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
|
public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, docVector);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
|
public double l2normSparse() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
|
|
||||||
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
||||||
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
||||||
|
|
||||||
int queryIndex = 0;
|
int queryIndex = 0;
|
||||||
int docIndex = 0;
|
int docIndex = 0;
|
||||||
double l2norm = 0;
|
double l2norm = 0;
|
||||||
|
@ -298,16 +345,15 @@ public class ScoreScriptUtils {
|
||||||
|
|
||||||
// Calculate a dot product between a query's sparse vector and documents' sparse vectors
|
// Calculate a dot product between a query's sparse vector and documents' sparse vectors
|
||||||
public static final class DotProductSparse extends SparseVectorFunction {
|
public static final class DotProductSparse extends SparseVectorFunction {
|
||||||
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
|
public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, docVector);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
|
public double dotProductSparse() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
|
|
||||||
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
||||||
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
||||||
|
|
||||||
return intDotProductSparse(queryValues, queryDims, docValues, docDims);
|
return intDotProductSparse(queryValues, queryDims, docValues, docDims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -316,8 +362,8 @@ public class ScoreScriptUtils {
|
||||||
public static final class CosineSimilaritySparse extends SparseVectorFunction {
|
public static final class CosineSimilaritySparse extends SparseVectorFunction {
|
||||||
final double queryVectorMagnitude;
|
final double queryVectorMagnitude;
|
||||||
|
|
||||||
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
|
public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
||||||
super(scoreScript, queryVector);
|
super(scoreScript, queryVector, docVector);
|
||||||
double dotProduct = 0;
|
double dotProduct = 0;
|
||||||
for (int i = 0; i< queryDims.length; i++) {
|
for (int i = 0; i< queryDims.length; i++) {
|
||||||
dotProduct += queryValues[i] * queryValues[i];
|
dotProduct += queryValues[i] * queryValues[i];
|
||||||
|
@ -325,10 +371,8 @@ public class ScoreScriptUtils {
|
||||||
this.queryVectorMagnitude = Math.sqrt(dotProduct);
|
this.queryVectorMagnitude = Math.sqrt(dotProduct);
|
||||||
}
|
}
|
||||||
|
|
||||||
public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
|
public double cosineSimilaritySparse() {
|
||||||
BytesRef vector = dvs.getEncodedValue();
|
BytesRef vector = getEncodedVector();
|
||||||
validateDocVector(vector);
|
|
||||||
|
|
||||||
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(scoreScript._getIndexVersion(), vector);
|
||||||
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
float[] docValues = VectorEncoderDecoder.decodeSparseVector(scoreScript._getIndexVersion(), vector);
|
||||||
|
|
||||||
|
|
|
@ -13,12 +13,12 @@ class org.elasticsearch.script.ScoreScript @no_import {
|
||||||
}
|
}
|
||||||
|
|
||||||
static_import {
|
static_import {
|
||||||
double l1norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
|
double l1norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
|
||||||
double l2norm(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
|
double l2norm(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm
|
||||||
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
|
double cosineSimilarity(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilarity
|
||||||
double dotProduct(org.elasticsearch.script.ScoreScript, List, VectorScriptDocValues.DenseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
|
double dotProduct(org.elasticsearch.script.ScoreScript, List, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProduct
|
||||||
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
|
double l1normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1NormSparse
|
||||||
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
|
double l2normSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2NormSparse
|
||||||
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
|
double dotProductSparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$DotProductSparse
|
||||||
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, VectorScriptDocValues.SparseVectorScriptDocValues) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
|
double cosineSimilaritySparse(org.elasticsearch.script.ScoreScript, Map, Object) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$CosineSimilaritySparse
|
||||||
}
|
}
|
|
@ -22,6 +22,7 @@ import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
|
||||||
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
|
import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2NormSparse;
|
||||||
|
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
@ -39,6 +40,66 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void testDenseVectorFunctions(Version indexVersion) {
|
private void testDenseVectorFunctions(Version indexVersion) {
|
||||||
|
String field = "vector";
|
||||||
|
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
||||||
|
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
|
||||||
|
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
|
||||||
|
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
||||||
|
|
||||||
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||||
|
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
|
||||||
|
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
|
||||||
|
|
||||||
|
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
||||||
|
|
||||||
|
// test dotProduct
|
||||||
|
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, field);
|
||||||
|
double result = dotProduct.dotProduct();
|
||||||
|
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
|
||||||
|
// test cosineSimilarity
|
||||||
|
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, field);
|
||||||
|
double result2 = cosineSimilarity.cosineSimilarity();
|
||||||
|
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
|
||||||
|
|
||||||
|
// test l1Norm
|
||||||
|
L1Norm l1norm = new L1Norm(scoreScript, queryVector, field);
|
||||||
|
double result3 = l1norm.l1norm();
|
||||||
|
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
|
||||||
|
|
||||||
|
// test l2norm
|
||||||
|
L2Norm l2norm = new L2Norm(scoreScript, queryVector, field);
|
||||||
|
double result4 = l2norm.l2norm();
|
||||||
|
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
|
||||||
|
|
||||||
|
// test dotProduct fails when queryVector has wrong number of dims
|
||||||
|
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
|
||||||
|
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, field);
|
||||||
|
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
|
||||||
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
|
||||||
|
// test cosineSimilarity fails when queryVector has wrong number of dims
|
||||||
|
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, field);
|
||||||
|
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
|
||||||
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
|
||||||
|
// test l1norm fails when queryVector has wrong number of dims
|
||||||
|
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, field);
|
||||||
|
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
|
||||||
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
|
||||||
|
// test l2norm fails when queryVector has wrong number of dims
|
||||||
|
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, field);
|
||||||
|
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
|
||||||
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDeprecatedDenseVectorFunctions() {
|
||||||
|
testDeprecatedDenseVectorFunctions(Version.V_7_4_0);
|
||||||
|
testDeprecatedDenseVectorFunctions(Version.CURRENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testDeprecatedDenseVectorFunctions(Version indexVersion) {
|
||||||
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
float[] docVector = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
||||||
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
|
BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
|
||||||
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
|
VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
|
||||||
|
@ -50,45 +111,53 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
|
||||||
|
|
||||||
// test dotProduct
|
// test dotProduct
|
||||||
DotProduct dotProduct = new DotProduct(scoreScript, queryVector);
|
DotProduct dotProduct = new DotProduct(scoreScript, queryVector, dvs);
|
||||||
double result = dotProduct.dotProduct(dvs);
|
double result = dotProduct.dotProduct();
|
||||||
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
|
assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test cosineSimilarity
|
// test cosineSimilarity
|
||||||
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector);
|
CosineSimilarity cosineSimilarity = new CosineSimilarity(scoreScript, queryVector, dvs);
|
||||||
double result2 = cosineSimilarity.cosineSimilarity(dvs);
|
double result2 = cosineSimilarity.cosineSimilarity();
|
||||||
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
|
assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result2, 0.001);
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l1Norm
|
// test l1Norm
|
||||||
L1Norm l1norm = new L1Norm(scoreScript, queryVector);
|
L1Norm l1norm = new L1Norm(scoreScript, queryVector, dvs);
|
||||||
double result3 = l1norm.l1norm(dvs);
|
double result3 = l1norm.l1norm();
|
||||||
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
|
assertEquals("l1norm result is not equal to the expected value!", 485.184, result3, 0.001);
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l2norm
|
// test l2norm
|
||||||
L2Norm l2norm = new L2Norm(scoreScript, queryVector);
|
L2Norm l2norm = new L2Norm(scoreScript, queryVector, dvs);
|
||||||
double result4 = l2norm.l2norm(dvs);
|
double result4 = l2norm.l2norm();
|
||||||
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
|
assertEquals("l2norm result is not equal to the expected value!", 301.361, result4, 0.001);
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test dotProduct fails when queryVector has wrong number of dims
|
// test dotProduct fails when queryVector has wrong number of dims
|
||||||
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
|
List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
|
||||||
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector);
|
DotProduct dotProduct2 = new DotProduct(scoreScript, invalidQueryVector, dvs);
|
||||||
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> dotProduct2.dotProduct(dvs));
|
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct2::dotProduct);
|
||||||
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test cosineSimilarity fails when queryVector has wrong number of dims
|
// test cosineSimilarity fails when queryVector has wrong number of dims
|
||||||
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector);
|
CosineSimilarity cosineSimilarity2 = new CosineSimilarity(scoreScript, invalidQueryVector, dvs);
|
||||||
e = expectThrows(IllegalArgumentException.class, () -> cosineSimilarity2.cosineSimilarity(dvs));
|
e = expectThrows(IllegalArgumentException.class, cosineSimilarity2::cosineSimilarity);
|
||||||
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l1norm fails when queryVector has wrong number of dims
|
// test l1norm fails when queryVector has wrong number of dims
|
||||||
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector);
|
L1Norm l1norm2 = new L1Norm(scoreScript, invalidQueryVector, dvs);
|
||||||
e = expectThrows(IllegalArgumentException.class, () -> l1norm2.l1norm(dvs));
|
e = expectThrows(IllegalArgumentException.class, l1norm2::l1norm);
|
||||||
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l2norm fails when queryVector has wrong number of dims
|
// test l2norm fails when queryVector has wrong number of dims
|
||||||
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector);
|
L2Norm l2norm2 = new L2Norm(scoreScript, invalidQueryVector, dvs);
|
||||||
e = expectThrows(IllegalArgumentException.class, () -> l2norm2.l2norm(dvs));
|
e = expectThrows(IllegalArgumentException.class, l2norm2::l2norm);
|
||||||
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
|
||||||
|
assertWarnings(ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSparseVectorFunctions() {
|
public void testSparseVectorFunctions() {
|
||||||
|
@ -97,12 +166,62 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void testSparseVectorFunctions(Version indexVersion) {
|
private void testSparseVectorFunctions(Version indexVersion) {
|
||||||
|
String field = "vector";
|
||||||
|
|
||||||
int[] docVectorDims = {2, 10, 50, 113, 4545};
|
int[] docVectorDims = {2, 10, 50, 113, 4545};
|
||||||
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
||||||
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
|
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
|
||||||
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
|
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
|
||||||
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
||||||
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
||||||
|
|
||||||
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||||
|
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
|
||||||
|
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
|
||||||
|
|
||||||
|
Map<String, Number> queryVector = new HashMap<String, Number>() {{
|
||||||
|
put("2", 0.5);
|
||||||
|
put("10", 111.3);
|
||||||
|
put("50", -13.0);
|
||||||
|
put("113", 14.8);
|
||||||
|
put("4545", -156.0);
|
||||||
|
}};
|
||||||
|
|
||||||
|
// test dotProduct
|
||||||
|
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
|
||||||
|
double result = docProductSparse.dotProductSparse();
|
||||||
|
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
|
||||||
|
// test cosineSimilarity
|
||||||
|
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
|
||||||
|
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
|
||||||
|
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
|
||||||
|
|
||||||
|
// test l1norm
|
||||||
|
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
|
||||||
|
double result3 = l1Norm.l1normSparse();
|
||||||
|
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
|
||||||
|
|
||||||
|
// test l2norm
|
||||||
|
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
|
||||||
|
double result4 = l2Norm.l2normSparse();
|
||||||
|
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testDeprecatedSparseVectorFunctions() {
|
||||||
|
testDeprecatedSparseVectorFunctions(Version.V_7_4_0);
|
||||||
|
testDeprecatedSparseVectorFunctions(Version.CURRENT);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void testDeprecatedSparseVectorFunctions(Version indexVersion) {
|
||||||
|
int[] docVectorDims = {2, 10, 50, 113, 4545};
|
||||||
|
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f};
|
||||||
|
BytesRef encodedDocVector = VectorEncoderDecoder.encodeSparseVector(
|
||||||
|
indexVersion, docVectorDims, docVectorValues, docVectorDims.length);
|
||||||
|
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
||||||
|
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
||||||
|
|
||||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||||
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
|
when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
|
||||||
|
|
||||||
|
@ -115,29 +234,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
}};
|
}};
|
||||||
|
|
||||||
// test dotProduct
|
// test dotProduct
|
||||||
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
|
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, dvs);
|
||||||
double result = docProductSparse.dotProductSparse(dvs);
|
double result = docProductSparse.dotProductSparse();
|
||||||
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test cosineSimilarity
|
// test cosineSimilarity
|
||||||
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
|
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, dvs);
|
||||||
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
|
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
|
||||||
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
|
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.790, result2, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l1norm
|
// test l1norm
|
||||||
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
|
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, dvs);
|
||||||
double result3 = l1Norm.l1normSparse(dvs);
|
double result3 = l1Norm.l1normSparse();
|
||||||
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
|
assertEquals("l1normSparse result is not equal to the expected value!", 485.184, result3, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l2norm
|
// test l2norm
|
||||||
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
|
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, dvs);
|
||||||
double result4 = l2Norm.l2normSparse(dvs);
|
double result4 = l2Norm.l2normSparse();
|
||||||
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
|
assertEquals("l2normSparse result is not equal to the expected value!", 301.361, result4, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE, ScoreScriptUtils.DEPRECATION_MESSAGE);
|
||||||
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSparseVectorMissingDimensions1() {
|
public void testSparseVectorMissingDimensions1() {
|
||||||
|
String field = "vector";
|
||||||
|
|
||||||
// Document vector's biggest dimension > query vector's biggest dimension
|
// Document vector's biggest dimension > query vector's biggest dimension
|
||||||
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
|
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
|
||||||
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
|
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
|
||||||
|
@ -145,8 +268,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
|
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
|
||||||
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
||||||
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
||||||
|
|
||||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||||
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
|
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
|
||||||
|
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
|
||||||
|
|
||||||
Map<String, Number> queryVector = new HashMap<String, Number>() {{
|
Map<String, Number> queryVector = new HashMap<String, Number>() {{
|
||||||
put("2", 0.5);
|
put("2", 0.5);
|
||||||
put("10", 111.3);
|
put("10", 111.3);
|
||||||
|
@ -157,29 +283,33 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
}};
|
}};
|
||||||
|
|
||||||
// test dotProduct
|
// test dotProduct
|
||||||
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
|
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
|
||||||
double result = docProductSparse.dotProductSparse(dvs);
|
double result = docProductSparse.dotProductSparse();
|
||||||
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test cosineSimilarity
|
// test cosineSimilarity
|
||||||
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
|
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
|
||||||
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
|
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
|
||||||
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
|
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l1norm
|
// test l1norm
|
||||||
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
|
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
|
||||||
double result3 = l1Norm.l1normSparse(dvs);
|
double result3 = l1Norm.l1normSparse();
|
||||||
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
|
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l2norm
|
// test l2norm
|
||||||
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
|
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
|
||||||
double result4 = l2Norm.l2normSparse(dvs);
|
double result4 = l2Norm.l2normSparse();
|
||||||
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
|
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
|
||||||
|
|
||||||
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testSparseVectorMissingDimensions2() {
|
public void testSparseVectorMissingDimensions2() {
|
||||||
|
String field = "vector";
|
||||||
|
|
||||||
// Document vector's biggest dimension < query vector's biggest dimension
|
// Document vector's biggest dimension < query vector's biggest dimension
|
||||||
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
|
int[] docVectorDims = {2, 10, 50, 113, 4545, 4546};
|
||||||
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
|
float[] docVectorValues = {230.0f, 300.33f, -34.8988f, 15.555f, -200.0f, 11.5f};
|
||||||
|
@ -187,8 +317,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
|
Version.CURRENT, docVectorDims, docVectorValues, docVectorDims.length);
|
||||||
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
VectorScriptDocValues.SparseVectorScriptDocValues dvs = mock(VectorScriptDocValues.SparseVectorScriptDocValues.class);
|
||||||
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
|
||||||
|
|
||||||
ScoreScript scoreScript = mock(ScoreScript.class);
|
ScoreScript scoreScript = mock(ScoreScript.class);
|
||||||
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
|
when(scoreScript._getIndexVersion()).thenReturn(Version.CURRENT);
|
||||||
|
when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, dvs));
|
||||||
|
|
||||||
Map<String, Number> queryVector = new HashMap<String, Number>() {{
|
Map<String, Number> queryVector = new HashMap<String, Number>() {{
|
||||||
put("2", 0.5);
|
put("2", 0.5);
|
||||||
put("10", 111.3);
|
put("10", 111.3);
|
||||||
|
@ -199,25 +332,27 @@ public class ScoreScriptUtilsTests extends ESTestCase {
|
||||||
}};
|
}};
|
||||||
|
|
||||||
// test dotProduct
|
// test dotProduct
|
||||||
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector);
|
DotProductSparse docProductSparse = new DotProductSparse(scoreScript, queryVector, field);
|
||||||
double result = docProductSparse.dotProductSparse(dvs);
|
double result = docProductSparse.dotProductSparse();
|
||||||
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test cosineSimilarity
|
// test cosineSimilarity
|
||||||
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector);
|
CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(scoreScript, queryVector, field);
|
||||||
double result2 = cosineSimilaritySparse.cosineSimilaritySparse(dvs);
|
double result2 = cosineSimilaritySparse.cosineSimilaritySparse();
|
||||||
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
|
assertEquals("cosineSimilaritySparse result is not equal to the expected value!", 0.786, result2, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l1norm
|
// test l1norm
|
||||||
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector);
|
L1NormSparse l1Norm = new L1NormSparse(scoreScript, queryVector, field);
|
||||||
double result3 = l1Norm.l1normSparse(dvs);
|
double result3 = l1Norm.l1normSparse();
|
||||||
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
|
assertEquals("l1normSparse result is not equal to the expected value!", 517.184, result3, 0.001);
|
||||||
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
|
|
||||||
// test l2norm
|
// test l2norm
|
||||||
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector);
|
L2NormSparse l2Norm = new L2NormSparse(scoreScript, queryVector, field);
|
||||||
double result4 = l2Norm.l2normSparse(dvs);
|
double result4 = l2Norm.l2normSparse();
|
||||||
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
|
assertEquals("l2normSparse result is not equal to the expected value!", 302.277, result4, 0.001);
|
||||||
|
|
||||||
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
assertWarnings(SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue