Add knn search parameter and remove knn query.

Original Pull Rrequest #2920
Closes #2919
This commit is contained in:
puppylpg 2024-05-29 02:52:47 +08:00 committed by GitHub
parent 9d139299b2
commit 687b014e70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 576 additions and 49 deletions

View File

@ -6,6 +6,17 @@ This section describes breaking changes from version 5.3.x to 5.4.x and how remo
[[elasticsearch-migration-guide-5.3-5.4.breaking-changes]]
== Breaking Changes
[[elasticsearch-migration-guide-5.3-5.4.breaking-changes.knn-search]]
=== knn search
The `withKnnQuery` method in `NativeQueryBuilder` has been replaced with `withKnnSearches` to build a `NativeQuery` with knn search.
`KnnQuery` and `KnnSearch` are two different classes in elasticsearch java client and are used for different queries, with different parameters supported:
- `KnnSearch`: is https://www.elastic.co/guide/en/elasticsearch/reference/8.13/search-search.html#search-api-knn[the top level `knn` query] in the elasticsearch request;
- `KnnQuery`: is https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-knn-query.html[the `knn` query inside `query` clause];
If `KnnQuery` is still preferable, please be sure to construct it inside `query` clause manually, by means of `withQuery(co.elastic.clients.elasticsearch._types.query_dsl.Query query)` clause in `NativeQueryBuilder`.
[[elasticsearch-migration-guide-5.3-5.4.deprecations]]
== Deprecations

View File

@ -37,6 +37,7 @@ import org.springframework.core.annotation.AliasFor;
* @author Brian Kimmig
* @author Morgan Lutz
* @author Sascha Woo
* @author Haibo Liu
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.FIELD, ElementType.ANNOTATION_TYPE, ElementType.METHOD })
@ -195,6 +196,27 @@ public @interface Field {
*/
int dims() default -1;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
String elementType() default FieldElementType.DEFAULT;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
KnnSimilarity knnSimilarity() default KnnSimilarity.DEFAULT;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
KnnIndexOptions[] knnIndexOptions() default {};
/**
* Controls how Elasticsearch dynamically adds fields to the inner object within the document.<br>
* To be used in combination with {@link FieldType#Object} or {@link FieldType#Nested}

View File

@ -0,0 +1,26 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.annotations;
/**
* @author Haibo Liu
* @since 5.4
*/
public final class FieldElementType {
public final static String DEFAULT = "";
public final static String FLOAT = "float";
public final static String BYTE = "byte";
}

View File

@ -29,6 +29,7 @@ import java.lang.annotation.Target;
* @author Aleksei Arsenev
* @author Brian Kimmig
* @author Morgan Lutz
* @author Haibo Liu
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.ANNOTATION_TYPE)
@ -149,4 +150,25 @@ public @interface InnerField {
* @since 4.2
*/
int dims() default -1;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
String elementType() default FieldElementType.DEFAULT;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
KnnSimilarity knnSimilarity() default KnnSimilarity.DEFAULT;
/**
* to be used in combination with {@link FieldType#Dense_Vector}
*
* @since 5.4
*/
KnnIndexOptions[] knnIndexOptions() default {};
}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.annotations;
/**
* @author Haibo Liu
* @since 5.4
*/
public enum KnnAlgorithmType {
HNSW("hnsw"),
INT8_HNSW("int8_hnsw"),
FLAT("flat"),
INT8_FLAT("int8_flat"),
DEFAULT("");
private final String type;
KnnAlgorithmType(String type) {
this.type = type;
}
public String getType() {
return type;
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.annotations;
/**
* @author Haibo Liu
* @since 5.4
*/
public @interface KnnIndexOptions {
KnnAlgorithmType type() default KnnAlgorithmType.DEFAULT;
/**
* Only applicable to {@link KnnAlgorithmType#HNSW} and {@link KnnAlgorithmType#INT8_HNSW} index types.
*/
int m() default -1;
/**
* Only applicable to {@link KnnAlgorithmType#HNSW} and {@link KnnAlgorithmType#INT8_HNSW} index types.
*/
int efConstruction() default -1;
/**
* Only applicable to {@link KnnAlgorithmType#INT8_HNSW} and {@link KnnAlgorithmType#INT8_FLAT} index types.
*/
float confidenceInterval() default -1F;
}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.annotations;
/**
* @author Haibo Liu
* @since 5.4
*/
public enum KnnSimilarity {
L2_NORM("l2_norm"),
DOT_PRODUCT("dot_product"),
COSINE("cosine"),
MAX_INNER_PRODUCT("max_inner_product"),
DEFAULT("");
private final String similarity;
KnnSimilarity(String similarity) {
this.similarity = similarity;
}
public String getSimilarity() {
return similarity;
}
}

View File

@ -15,7 +15,6 @@
*/
package org.springframework.data.elasticsearch.client.elc;
import co.elastic.clients.elasticsearch._types.KnnQuery;
import co.elastic.clients.elasticsearch._types.KnnSearch;
import co.elastic.clients.elasticsearch._types.SortOptions;
import co.elastic.clients.elasticsearch._types.aggregations.Aggregation;
@ -30,7 +29,6 @@ import java.util.List;
import java.util.Map;
import org.springframework.data.elasticsearch.core.query.BaseQuery;
import org.springframework.data.elasticsearch.core.query.ScriptedField;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
@ -40,6 +38,7 @@ import org.springframework.util.Assert;
*
* @author Peter-Josef Meisch
* @author Sascha Woo
* @author Haibo Liu
* @since 4.4
*/
public class NativeQuery extends BaseQuery {
@ -54,7 +53,6 @@ public class NativeQuery extends BaseQuery {
private List<SortOptions> sortOptions = Collections.emptyList();
private Map<String, JsonData> searchExtensions = Collections.emptyMap();
@Nullable private KnnQuery knnQuery;
@Nullable private List<KnnSearch> knnSearches = Collections.emptyList();
public NativeQuery(NativeQueryBuilder builder) {
@ -72,7 +70,6 @@ public class NativeQuery extends BaseQuery {
"Cannot add an NativeQuery in a NativeQuery");
}
this.springDataQuery = builder.getSpringDataQuery();
this.knnQuery = builder.getKnnQuery();
this.knnSearches = builder.getKnnSearches();
}
@ -124,14 +121,6 @@ public class NativeQuery extends BaseQuery {
this.springDataQuery = springDataQuery;
}
/**
* @since 5.1
*/
@Nullable
public KnnQuery getKnnQuery() {
return knnQuery;
}
/**
* @since 5.3.1
*/

View File

@ -40,6 +40,7 @@ import org.springframework.util.Assert;
/**
* @author Peter-Josef Meisch
* @author Sascha Woo
* @author Haibo Liu
* @since 4.4
*/
public class NativeQueryBuilder extends BaseQueryBuilder<NativeQuery, NativeQueryBuilder> {
@ -213,13 +214,30 @@ public class NativeQueryBuilder extends BaseQueryBuilder<NativeQuery, NativeQuer
}
/**
* @since 5.1
* @since 5.4
*/
public NativeQueryBuilder withKnnQuery(KnnQuery knnQuery) {
this.knnQuery = knnQuery;
public NativeQueryBuilder withKnnSearches(List<KnnSearch> knnSearches) {
this.knnSearches = knnSearches;
return this;
}
/**
* @since 5.4
*/
public NativeQueryBuilder withKnnSearches(Function<KnnSearch.Builder, ObjectBuilder<KnnSearch>> fn) {
Assert.notNull(fn, "fn must not be null");
return withKnnSearches(fn.apply(new KnnSearch.Builder()).build());
}
/**
* @since 5.4
*/
public NativeQueryBuilder withKnnSearches(KnnSearch knnSearch) {
return withKnnSearches(List.of(knnSearch));
}
public NativeQuery build() {
Assert.isTrue(query == null || springDataQuery == null, "Cannot have both a native query and a Spring Data query");
return new NativeQuery(this);

View File

@ -1377,7 +1377,7 @@ class RequestConverter extends AbstractQueryProcessor {
private Function<MultisearchHeader.Builder, ObjectBuilder<MultisearchHeader>> msearchHeaderBuilder(Query query,
IndexCoordinates index, @Nullable String routing) {
return h -> {
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
var searchType = (query instanceof NativeQuery nativeQuery && !isEmpty(nativeQuery.getKnnSearches())) ? null
: searchType(query.getSearchType());
h //
@ -1409,7 +1409,7 @@ class RequestConverter extends AbstractQueryProcessor {
ElasticsearchPersistentEntity<?> persistentEntity = getPersistentEntity(clazz);
var searchType = (query instanceof NativeQuery nativeQuery && nativeQuery.getKnnQuery() != null) ? null
var searchType = (query instanceof NativeQuery nativeQuery && !isEmpty(nativeQuery.getKnnSearches())) ? null
: searchType(query.getSearchType());
builder //
@ -1728,17 +1728,6 @@ class RequestConverter extends AbstractQueryProcessor {
.sort(query.getSortOptions()) //
;
if (query.getKnnQuery() != null) {
var kq = query.getKnnQuery();
builder.knn(ksb -> ksb
.field(kq.field())
.queryVector(kq.queryVector())
.numCandidates(kq.numCandidates())
.filter(kq.filter())
.similarity(kq.similarity()));
}
if (!isEmpty(query.getKnnSearches())) {
builder.knn(query.getKnnSearches());
}
@ -1760,17 +1749,6 @@ class RequestConverter extends AbstractQueryProcessor {
.collapse(query.getFieldCollapse()) //
.sort(query.getSortOptions());
if (query.getKnnQuery() != null) {
var kq = query.getKnnQuery();
builder.knn(ksb -> ksb
.field(kq.field())
.queryVector(kq.queryVector())
.numCandidates(kq.numCandidates())
.filter(kq.filter())
.similarity(kq.similarity()));
}
if (!isEmpty(query.getKnnSearches())) {
builder.knn(query.getKnnSearches());
}

View File

@ -23,15 +23,7 @@ import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.springframework.data.elasticsearch.annotations.DateFormat;
import org.springframework.data.elasticsearch.annotations.Field;
import org.springframework.data.elasticsearch.annotations.FieldType;
import org.springframework.data.elasticsearch.annotations.IndexOptions;
import org.springframework.data.elasticsearch.annotations.IndexPrefixes;
import org.springframework.data.elasticsearch.annotations.InnerField;
import org.springframework.data.elasticsearch.annotations.NullValueType;
import org.springframework.data.elasticsearch.annotations.Similarity;
import org.springframework.data.elasticsearch.annotations.TermVector;
import org.springframework.data.elasticsearch.annotations.*;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -49,6 +41,7 @@ import com.fasterxml.jackson.databind.node.TextNode;
* @author Brian Kimmig
* @author Morgan Lutz
* @author Sascha Woo
* @author Haibo Liu
* @since 4.0
*/
public final class MappingParameters {
@ -78,6 +71,10 @@ public final class MappingParameters {
static final String FIELD_PARAM_ORIENTATION = "orientation";
static final String FIELD_PARAM_POSITIVE_SCORE_IMPACT = "positive_score_impact";
static final String FIELD_PARAM_DIMS = "dims";
static final String FIELD_PARAM_ELEMENT_TYPE = "element_type";
static final String FIELD_PARAM_M = "m";
static final String FIELD_PARAM_EF_CONSTRUCTION = "ef_construction";
static final String FIELD_PARAM_CONFIDENCE_INTERVAL = "confidence_interval";
static final String FIELD_PARAM_SCALING_FACTOR = "scaling_factor";
static final String FIELD_PARAM_SEARCH_ANALYZER = "search_analyzer";
static final String FIELD_PARAM_STORE = "store";
@ -110,6 +107,9 @@ public final class MappingParameters {
private final Integer positionIncrementGap;
private final boolean positiveScoreImpact;
private final Integer dims;
private final String elementType;
private final KnnSimilarity knnSimilarity;
@Nullable private final KnnIndexOptions knnIndexOptions;
private final String searchAnalyzer;
private final double scalingFactor;
private final String similarity;
@ -174,6 +174,9 @@ public final class MappingParameters {
Assert.isTrue(dims >= 1 && dims <= 4096,
"Invalid required parameter! Dense_Vector value \"dims\" must be between 1 and 4096.");
}
elementType = field.elementType();
knnSimilarity = field.knnSimilarity();
knnIndexOptions = field.knnIndexOptions().length > 0 ? field.knnIndexOptions()[0] : null;
Assert.isTrue(field.enabled() || type == FieldType.Object, "enabled false is only allowed for field type object");
enabled = field.enabled();
eagerGlobalOrdinals = field.eagerGlobalOrdinals();
@ -217,6 +220,9 @@ public final class MappingParameters {
Assert.isTrue(dims >= 1 && dims <= 4096,
"Invalid required parameter! Dense_Vector value \"dims\" must be between 1 and 4096.");
}
elementType = field.elementType();
knnSimilarity = field.knnSimilarity();
knnIndexOptions = field.knnIndexOptions().length > 0 ? field.knnIndexOptions()[0] : null;
enabled = true;
eagerGlobalOrdinals = field.eagerGlobalOrdinals();
}
@ -356,6 +362,48 @@ public final class MappingParameters {
if (type == FieldType.Dense_Vector) {
objectNode.put(FIELD_PARAM_DIMS, dims);
if (!FieldElementType.DEFAULT.equals(elementType)) {
objectNode.put(FIELD_PARAM_ELEMENT_TYPE, elementType);
}
if (knnSimilarity != KnnSimilarity.DEFAULT) {
objectNode.put(FIELD_PARAM_SIMILARITY, knnSimilarity.getSimilarity());
}
if (knnSimilarity != KnnSimilarity.DEFAULT) {
Assert.isTrue(index, "knn similarity can only be specified when 'index' is true.");
objectNode.put(FIELD_PARAM_SIMILARITY, knnSimilarity.getSimilarity());
}
if (knnIndexOptions != null) {
Assert.isTrue(index, "knn index options can only be specified when 'index' is true.");
ObjectNode indexOptionsNode = objectNode.putObject(FIELD_PARAM_INDEX_OPTIONS);
KnnAlgorithmType algoType = knnIndexOptions.type();
if (algoType != KnnAlgorithmType.DEFAULT) {
if (algoType == KnnAlgorithmType.INT8_HNSW || algoType == KnnAlgorithmType.INT8_FLAT) {
Assert.isTrue(!FieldElementType.BYTE.equals(elementType),
"'element_type' can only be float when using vector quantization.");
}
indexOptionsNode.put(FIELD_PARAM_TYPE, algoType.getType());
}
if (knnIndexOptions.m() >= 0) {
Assert.isTrue(algoType == KnnAlgorithmType.HNSW || algoType == KnnAlgorithmType.INT8_HNSW,
"knn 'm' parameter can only be applicable to hnsw and int8_hnsw index types.");
indexOptionsNode.put(FIELD_PARAM_M, knnIndexOptions.m());
}
if (knnIndexOptions.efConstruction() >= 0) {
Assert.isTrue(algoType == KnnAlgorithmType.HNSW || algoType == KnnAlgorithmType.INT8_HNSW,
"knn 'ef_construction' can only be applicable to hnsw and int8_hnsw index types.");
indexOptionsNode.put(FIELD_PARAM_EF_CONSTRUCTION, knnIndexOptions.efConstruction());
}
if (knnIndexOptions.confidenceInterval() >= 0) {
Assert.isTrue(algoType == KnnAlgorithmType.INT8_HNSW
|| algoType == KnnAlgorithmType.INT8_FLAT,
"knn 'confidence_interval' can only be applicable to int8_hnsw and int8_flat index types.");
indexOptionsNode.put(FIELD_PARAM_CONFIDENCE_INTERVAL, knnIndexOptions.confidenceInterval());
}
}
}
if (!enabled) {

View File

@ -58,6 +58,7 @@ import org.springframework.lang.Nullable;
* @author Roman Puchkovskiy
* @author Brian Kimmig
* @author Morgan Lutz
* @author Haibo Liu
*/
@SpringIntegrationTest
public abstract class MappingBuilderIntegrationTests extends MappingContextBaseTests {
@ -908,7 +909,8 @@ public abstract class MappingBuilderIntegrationTests extends MappingContextBaseT
@Nullable
@Id private String id;
@Field(type = FieldType.Dense_Vector, dims = 42, similarity = "cosine") private double[] denseVector;
@Field(type = FieldType.Dense_Vector, dims = 42, knnSimilarity = KnnSimilarity.COSINE)
private double[] denseVector;
}
@Mapping(aliases = {

View File

@ -62,6 +62,7 @@ import org.springframework.lang.Nullable;
* @author Roman Puchkovskiy
* @author Brian Kimmig
* @author Morgan Lutz
* @author Haibo Liu
*/
public class MappingBuilderUnitTests extends MappingContextBaseTests {
@ -695,6 +696,32 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
assertEquals(expected, mapping, false);
}
@Test
@DisplayName("should write dense_vector properties for knn search")
void shouldWriteDenseVectorPropertiesWithKnnSearch() throws JSONException {
String expected = """
{
"properties":{
"my_vector":{
"type":"dense_vector",
"dims":16,
"element_type":"float",
"similarity":"dot_product",
"index_options":{
"type":"hnsw",
"m":16,
"ef_construction":100
}
}
}
}
""";
String mapping = getMappingBuilder().buildPropertyMapping(DenseVectorEntityWithKnnSearch.class);
assertEquals(expected, mapping, false);
}
@Test // #1370
@DisplayName("should not write mapping when enabled is false on entity")
void shouldNotWriteMappingWhenEnabledIsFalseOnEntity() throws JSONException {
@ -741,6 +768,14 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
.isInstanceOf(MappingException.class);
}
@Test
@DisplayName("should match confidence interval parameter for dense_vector type")
void shouldMatchConfidenceIntervalParameterForDenseVectorType() {
assertThatThrownBy(() -> getMappingBuilder().buildPropertyMapping(DenseVectorMisMatchConfidenceIntervalClass.class))
.isInstanceOf(IllegalArgumentException.class);
}
@Test // #1711
@DisplayName("should write typeHint entries")
void shouldWriteTypeHintEntries() throws JSONException {
@ -2063,6 +2098,36 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
}
}
@SuppressWarnings("unused")
static class DenseVectorEntityWithKnnSearch {
@Nullable
@Id private String id;
@Nullable
@Field(type = FieldType.Dense_Vector, dims = 16, elementType = FieldElementType.FLOAT,
knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, efConstruction = 100),
knnSimilarity = KnnSimilarity.DOT_PRODUCT)
private float[] my_vector;
@Nullable
public String getId() {
return id;
}
public void setId(@Nullable String id) {
this.id = id;
}
@Nullable
public float[] getMy_vector() {
return my_vector;
}
public void setMy_vector(@Nullable float[] my_vector) {
this.my_vector = my_vector;
}
}
@Mapping(enabled = false)
static class DisabledMappingEntity {
@Nullable
@ -2115,6 +2180,13 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
}
}
static class DenseVectorMisMatchConfidenceIntervalClass {
@Field(type = Dense_Vector, dims = 16, elementType = FieldElementType.FLOAT,
knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, confidenceInterval = 0.95F),
knnSimilarity = KnnSimilarity.DOT_PRODUCT)
private float[] dense_vector;
}
static class DisabledMappingProperty {
@Nullable
@Id private String id;

View File

@ -0,0 +1,44 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.repositories.knn;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.data.elasticsearch.junit.jupiter.ElasticsearchTemplateConfiguration;
import org.springframework.data.elasticsearch.repository.config.EnableElasticsearchRepositories;
import org.springframework.data.elasticsearch.utils.IndexNameProvider;
import org.springframework.test.context.ContextConfiguration;
/**
* @author Haibo Liu
* @since 5.4
*/
@ContextConfiguration(classes = { KnnSearchELCIntegrationTests.Config.class })
public class KnnSearchELCIntegrationTests extends KnnSearchIntegrationTests {
@Configuration
@Import({ ElasticsearchTemplateConfiguration.class })
@EnableElasticsearchRepositories(
basePackages = { "org.springframework.data.elasticsearch.repositories.knn" },
considerNestedRepositories = true)
static class Config {
@Bean
IndexNameProvider indexNameProvider() {
return new IndexNameProvider("knn-repository");
}
}
}

View File

@ -0,0 +1,179 @@
/*
* Copyright 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.elasticsearch.repositories.knn;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.*;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.annotation.Id;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.annotations.*;
import org.springframework.data.elasticsearch.client.elc.NativeQuery;
import org.springframework.data.elasticsearch.client.elc.NativeQueryBuilder;
import org.springframework.data.elasticsearch.core.ElasticsearchOperations;
import org.springframework.data.elasticsearch.core.SearchHit;
import org.springframework.data.elasticsearch.core.SearchHits;
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
import org.springframework.data.elasticsearch.junit.jupiter.SpringIntegrationTest;
import org.springframework.data.elasticsearch.repository.ElasticsearchRepository;
import org.springframework.data.elasticsearch.utils.IndexNameProvider;
import org.springframework.lang.Nullable;
/**
* @author Haibo Liu
* @since 5.4
*/
@SpringIntegrationTest
public abstract class KnnSearchIntegrationTests {
@Autowired ElasticsearchOperations operations;
@Autowired private IndexNameProvider indexNameProvider;
@Autowired private VectorEntityRepository vectorEntityRepository;
@BeforeEach
public void before() {
indexNameProvider.increment();
operations.indexOps(VectorEntity.class).createWithMapping();
}
@Test
@org.junit.jupiter.api.Order(java.lang.Integer.MAX_VALUE)
void cleanup() {
operations.indexOps(IndexCoordinates.of(indexNameProvider.getPrefix() + "*")).delete();
}
private List<VectorEntity> createVectorEntities(int n) {
List<VectorEntity> entities = new ArrayList<>();
float increment = 1.0f / n;
for (int i = 0; i < n; i++) {
VectorEntity entity = new VectorEntity();
entity.setId(UUID.randomUUID().toString());
entity.setMessage("top" + (i + 1));
// The generated vector is always in the first quadrant, from the x-axis direction to the y-axis direction
float[] vector = new float[] {1.0f - i * increment, increment};
entity.setVector(vector);
entities.add(entity);
}
return entities;
}
@Test
public void shouldReturnXAxisVector() {
// given
List<VectorEntity> entities = createVectorEntities(5);
vectorEntityRepository.saveAll(entities);
List<Float> xAxisVector = List.of(100f, 0f);
// when
NativeQuery query = new NativeQueryBuilder()
.withKnnSearches(ksb -> ksb.queryVector(xAxisVector).k(3L).field("vector"))
.withPageable(Pageable.ofSize(2))
.build();
SearchHits<VectorEntity> result = operations.search(query, VectorEntity.class);
List<VectorEntity> vectorEntities = result.getSearchHits().stream().map(SearchHit::getContent).toList();
// then
assertThat(result).isNotNull();
assertThat(result.getTotalHits()).isEqualTo(3L);
// should return the first vector, because it's near x-axis
assertThat(vectorEntities.get(0).getMessage()).isEqualTo("top1");
}
@Test
public void shouldReturnYAxisVector() {
// given
List<VectorEntity> entities = createVectorEntities(10);
vectorEntityRepository.saveAll(entities);
List<Float> yAxisVector = List.of(0f, 100f);
// when
NativeQuery query = new NativeQueryBuilder()
.withKnnSearches(ksb -> ksb.queryVector(yAxisVector).k(3L).field("vector"))
.withPageable(Pageable.ofSize(2))
.build();
SearchHits<VectorEntity> result = operations.search(query, VectorEntity.class);
List<VectorEntity> vectorEntities = result.getSearchHits().stream().map(SearchHit::getContent).toList();
// then
assertThat(result).isNotNull();
assertThat(result.getTotalHits()).isEqualTo(3L);
// should return the last vector, because it's near y-axis
assertThat(vectorEntities.get(0).getMessage()).isEqualTo("top10");
}
public interface VectorEntityRepository extends ElasticsearchRepository<VectorEntity, String> {
}
@Document(indexName = "#{@indexNameProvider.indexName()}")
static class VectorEntity {
@Nullable
@Id
private String id;
@Nullable
@Field(type = Keyword)
private String message;
// TODO: `elementType = FieldElementType.FLOAT,` is to be added here later
// TODO: element_type can not be set here, because it's left out in elasticsearch-specification
// TODO: the issue is fixed in https://github.com/elastic/elasticsearch-java/pull/800, but still not released in 8.13.x
// TODO: will be fixed later by either upgrading to 8.14.0 or a newer 8.13.x
@Field(type = FieldType.Dense_Vector, dims = 2,
knnIndexOptions = @KnnIndexOptions(type = KnnAlgorithmType.HNSW, m = 16, efConstruction = 100),
knnSimilarity = KnnSimilarity.COSINE)
private float[] vector;
@Nullable
public String getId() {
return id;
}
public void setId(@Nullable String id) {
this.id = id;
}
@Nullable
public String getMessage() {
return message;
}
public void setMessage(@Nullable String message) {
this.message = message;
}
@Nullable
public float[] getVector() {
return vector;
}
public void setVector(@Nullable float[] vector) {
this.vector = vector;
}
}
}