Add support for dense_vector type

Original Pull Request  #1708
Closes #1700
This commit is contained in:
Morgan 2021-02-25 23:25:27 -08:00 committed by GitHub
parent 8da718e41a
commit 3f2ab4b06a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 115 additions and 1 deletions

View File

@ -34,6 +34,8 @@ import org.springframework.core.annotation.AliasFor;
* @author Peter-Josef Meisch
* @author Xiao Yu
* @author Aleksei Arsenev
* @author Brian Kimmig
* @author Morgan Lutz
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ ElementType.FIELD, ElementType.ANNOTATION_TYPE })
@ -185,4 +187,11 @@ public @interface Field {
* @since 4.1
*/
NullValueType nullValueType() default NullValueType.String;
/**
* to be used in combination with {@link FieldType#Dense_Vector}. For that field type the value is required and must
* be in the inclusive range from 1 to 2048; the default of -1 lies outside that range, so it has to be set
* explicitly. The value is written to the mapping as the "dims" parameter for dense_vector fields only.
*
* @since 4.2
*/
int dims() default -1;
}

View File

@ -22,6 +22,8 @@ package org.springframework.data.elasticsearch.annotations;
* @author Zeng Zetang
* @author Peter-Josef Meisch
* @author Aleksei Arsenev
* @author Brian Kimmig
* @author Morgan Lutz
*/
public enum FieldType {
Auto, //
@ -57,5 +59,7 @@ public enum FieldType {
/** @since 4.1 */
Rank_Features, //
/** @since 4.2 */
Wildcard //
Wildcard, //
/** @since 4.2 */
Dense_Vector //
}

View File

@ -27,6 +27,8 @@ import java.lang.annotation.Target;
* @author Xiao Yu
* @author Peter-Josef Meisch
* @author Aleksei Arsenev
* @author Brian Kimmig
* @author Morgan Lutz
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.ANNOTATION_TYPE)
@ -140,4 +142,11 @@ public @interface InnerField {
* @since 4.1
*/
NullValueType nullValueType() default NullValueType.String;
/**
* to be used in combination with {@link FieldType#Dense_Vector}. For that field type the value is required and must
* be in the inclusive range from 1 to 2048; the default of -1 lies outside that range, so it has to be set
* explicitly. The value is written to the mapping as the "dims" parameter for dense_vector fields only.
*
* @since 4.2
*/
int dims() default -1;
}

View File

@ -39,6 +39,8 @@ import org.springframework.util.StringUtils;
*
* @author Peter-Josef Meisch
* @author Aleksei Arsenev
* @author Brian Kimmig
* @author Morgan Lutz
* @since 4.0
*/
public final class MappingParameters {
@ -65,6 +67,7 @@ public final class MappingParameters {
static final String FIELD_PARAM_NULL_VALUE = "null_value";
static final String FIELD_PARAM_POSITION_INCREMENT_GAP = "position_increment_gap";
static final String FIELD_PARAM_POSITIVE_SCORE_IMPACT = "positive_score_impact";
static final String FIELD_PARAM_DIMS = "dims";
static final String FIELD_PARAM_SCALING_FACTOR = "scaling_factor";
static final String FIELD_PARAM_SEARCH_ANALYZER = "search_analyzer";
static final String FIELD_PARAM_STORE = "store";
@ -94,6 +97,7 @@ public final class MappingParameters {
private final NullValueType nullValueType;
private final Integer positionIncrementGap;
private final boolean positiveScoreImpact;
private final Integer dims;
private final String searchAnalyzer;
private final double scalingFactor;
private final Similarity similarity;
@ -153,6 +157,10 @@ public final class MappingParameters {
|| (maxShingleSize >= 2 && maxShingleSize <= 4), //
"maxShingleSize must be in inclusive range from 2 to 4 for field type search_as_you_type");
positiveScoreImpact = field.positiveScoreImpact();
dims = field.dims();
if (type == FieldType.Dense_Vector) {
Assert.isTrue(dims >= 1 && dims <= 2048, "Invalid required parameter! Dense_Vector value \"dims\" must be between 1 and 2048.");
}
Assert.isTrue(field.enabled() || type == FieldType.Object, "enabled false is only allowed for field type object");
enabled = field.enabled();
eagerGlobalOrdinals = field.eagerGlobalOrdinals();
@ -191,6 +199,10 @@ public final class MappingParameters {
|| (maxShingleSize >= 2 && maxShingleSize <= 4), //
"maxShingleSize must be in inclusive range from 2 to 4 for field type search_as_you_type");
positiveScoreImpact = field.positiveScoreImpact();
dims = field.dims();
if (type == FieldType.Dense_Vector) {
Assert.isTrue(dims >= 1 && dims <= 2048, "Invalid required parameter! Dense_Vector value \"dims\" must be between 1 and 2048.");
}
enabled = true;
eagerGlobalOrdinals = field.eagerGlobalOrdinals();
}
@ -323,6 +335,10 @@ public final class MappingParameters {
builder.field(FIELD_PARAM_POSITIVE_SCORE_IMPACT, positiveScoreImpact);
}
if (type == FieldType.Dense_Vector) {
builder.field(FIELD_PARAM_DIMS, dims);
}
if (!enabled) {
builder.field(FIELD_PARAM_ENABLED, enabled);
}

View File

@ -82,6 +82,8 @@ import org.springframework.test.context.ContextConfiguration;
* @author Peter-Josef Meisch
* @author Xiao Yu
* @author Roman Puchkovskiy
* @author Brian Kimmig
* @author Morgan Lutz
*/
@SpringIntegrationTest
@ContextConfiguration(classes = { ElasticsearchRestTemplateConfiguration.class })
@ -271,6 +273,16 @@ public class MappingBuilderIntegrationTests extends MappingContextBaseTests {
indexOps.putMapping();
}
@Test // #1700
@DisplayName("should write dense_vector field mapping")
void shouldWriteDenseVectorFieldMapping() {

	IndexOperations indexOps = operations.indexOps(DenseVectorEntity.class);
	indexOps.create();

	try {
		// the test passes when the mapping is written without an exception being thrown
		indexOps.putMapping();
	} finally {
		// remove the index even when putMapping() fails, so it is not left behind for other tests or later runs
		indexOps.delete();
	}
}
@Test // #1370
@DisplayName("should write mapping for disabled entity")
void shouldWriteMappingForDisabledEntity() {
@ -657,4 +669,11 @@ public class MappingBuilderIntegrationTests extends MappingContextBaseTests {
@Field(type = Text) private String text;
@Mapping(enabled = false) @Field(type = Object) private Object object;
}
/**
* Entity with a dense_vector field of three dimensions, stored in the "densevector-test" index; used by the test that
* creates the index and writes the dense_vector mapping.
*/
@Data
@Document(indexName = "densevector-test")
static class DenseVectorEntity {
@Id private String id;
@Field(type = Dense_Vector, dims = 3) private float[] dense_vector;
}
}

View File

@ -69,6 +69,8 @@ import org.springframework.lang.Nullable;
* @author Peter-Josef Meisch
* @author Xiao Yu
* @author Roman Puchkovskiy
* @author Brian Kimmig
* @author Morgan Lutz
*/
public class MappingBuilderUnitTests extends MappingContextBaseTests {
@ -506,6 +508,23 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
assertEquals(expected, mapping, false);
}
@Test // #1700
@DisplayName("should write dense_vector properties")
void shouldWriteDenseVectorProperties() throws JSONException {

	// build the mapping for the entity under test first
	String actualJson = getMappingBuilder().buildPropertyMapping(DenseVectorEntity.class);

	// the field must be mapped as dense_vector together with the dims value from the annotation
	String expectedJson = "{\n" //
			+ " \"properties\": {\n" //
			+ " \"my_vector\": {\n" //
			+ " \"type\": \"dense_vector\",\n" //
			+ " \"dims\": 16\n" //
			+ " }\n" //
			+ " }\n" //
			+ "}\n";

	// non-strict comparison
	assertEquals(expectedJson, actualJson, false);
}
@Test // #1370
@DisplayName("should not write mapping when enabled is false on entity")
void shouldNotWriteMappingWhenEnabledIsFalseOnEntity() throws JSONException {
@ -963,6 +982,13 @@ public class MappingBuilderUnitTests extends MappingContextBaseTests {
@Field(type = FieldType.Rank_Features) private Map<String, Integer> topics;
}
/**
* Entity with a dense_vector field of 16 dimensions; used to verify the property mapping that is built for
* dense_vector fields.
*/
@Data
static class DenseVectorEntity {
@Id private String id;
@Field(type = FieldType.Dense_Vector, dims = 16) private float[] my_vector;
}
@Data
@Mapping(enabled = false)
static class DisabledMappingEntity {

View File

@ -1,6 +1,7 @@
package org.springframework.data.elasticsearch.core.index;
import static org.assertj.core.api.Assertions.*;
import static org.springframework.data.elasticsearch.annotations.FieldType.Dense_Vector;
import static org.springframework.data.elasticsearch.annotations.FieldType.Object;
import java.lang.annotation.Annotation;
@ -17,6 +18,8 @@ import org.springframework.lang.Nullable;
/**
* @author Peter-Josef Meisch
* @author Brian Kimmig
* @author Morgan Lutz
*/
public class MappingParametersTest extends MappingContextBaseTests {
@ -66,6 +69,26 @@ public class MappingParametersTest extends MappingContextBaseTests {
assertThatThrownBy(() -> MappingParameters.from(annotation)).isInstanceOf(IllegalArgumentException.class);
}
@Test // #1700
@DisplayName("should not allow dims length greater than 2048 for dense_vector type")
void shouldNotAllowDimsLengthGreaterThan2048ForDenseVectorType() {

	// the fixture class declares dims = 2049, one more than the allowed maximum
	ElasticsearchPersistentEntity<?> entity = elasticsearchConverter.get() //
			.getMappingContext() //
			.getRequiredPersistentEntity(DenseVectorInvalidDimsClass.class);
	Annotation fieldAnnotation = entity //
			.getRequiredPersistentProperty("dense_vector") //
			.findAnnotation(Field.class);

	assertThatThrownBy(() -> MappingParameters.from(fieldAnnotation)) //
			.isInstanceOf(IllegalArgumentException.class);
}
@Test // #1700
@DisplayName("should require dims parameter for dense_vector type")
void shouldRequireDimsParameterForDenseVectorType() {

	// the fixture class does not set dims, so the annotation default is used
	ElasticsearchPersistentEntity<?> entity = elasticsearchConverter.get() //
			.getMappingContext() //
			.getRequiredPersistentEntity(DenseVectorMissingDimsClass.class);
	Annotation fieldAnnotation = entity //
			.getRequiredPersistentProperty("dense_vector") //
			.findAnnotation(Field.class);

	assertThatThrownBy(() -> MappingParameters.from(fieldAnnotation)) //
			.isInstanceOf(IllegalArgumentException.class);
}
static class AnnotatedClass {
@Nullable @Field private String field;
@Nullable @MultiField(mainField = @Field,
@ -79,4 +102,12 @@ public class MappingParametersTest extends MappingContextBaseTests {
static class InvalidEnabledFieldClass {
@Nullable @Field(type = FieldType.Text, enabled = false) private String disabledObject;
}
/** Fixture whose dense_vector field declares dims = 2049, which is above the allowed maximum of 2048. */
static class DenseVectorInvalidDimsClass {
@Field(type = Dense_Vector, dims = 2049) private float[] dense_vector;
}
/** Fixture whose dense_vector field does not set dims, so the annotation default of -1 applies. */
static class DenseVectorMissingDimsClass {
@Field(type = Dense_Vector) private float[] dense_vector;
}
}