HHH-18900 MariaDB Vector support

+ adding support and test correction for mariadb 11.6.2 snapshot isolation
This commit is contained in:
diego 2024-12-19 16:24:40 +00:00 committed by Christian Beikov
parent d22aeb1a52
commit 65a26a214f
16 changed files with 481 additions and 17 deletions

View File

@ -4,4 +4,4 @@
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.4.0'
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.5.1'

View File

@ -92,7 +92,7 @@ mysql_8_2() {
}
mariadb() {
mariadb_11_4
mariadb_11_7
}
mariadb_wait_until_start()
@ -138,6 +138,12 @@ mariadb_11_4() {
mariadb_wait_until_start
}
# Starts a fresh container named "mariadb" from the MariaDB 11.7 image.
# The image defaults to docker.io/mariadb:11.7-rc and can be overridden
# through the DB_IMAGE_MARIADB_11_7 environment variable.
mariadb_11_7() {
    # Remove any container left over from an earlier run; failure to remove is ignored.
    $CONTAINER_CLI rm -f mariadb || true
    $CONTAINER_CLI run --name mariadb \
        -e MARIADB_USER=hibernate_orm_test \
        -e MARIADB_PASSWORD=hibernate_orm_test \
        -e MARIADB_DATABASE=hibernate_orm_test \
        -e MARIADB_ROOT_PASSWORD=hibernate_orm_test \
        -p3306:3306 \
        -d ${DB_IMAGE_MARIADB_11_7:-docker.io/mariadb:11.7-rc} \
        --character-set-server=utf8mb4 \
        --collation-server=utf8mb4_bin \
        --skip-character-set-client-handshake \
        --lower_case_table_names=2
    mariadb_wait_until_start
}
mariadb_verylatest() {
$CONTAINER_CLI rm -f mariadb || true
$CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_VERYLATEST:-quay.io/mariadb-foundation/mariadb-devel:verylatest} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2
@ -996,6 +1002,7 @@ if [ -z ${1} ]; then
echo -e "\thana"
echo -e "\tmariadb"
echo -e "\tmariadb_verylatest"
echo -e "\tmariadb_11_7"
echo -e "\tmariadb_11_4"
echo -e "\tmariadb_11_1"
echo -e "\tmariadb_10_11"

View File

@ -8,6 +8,7 @@ import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.sql.Types;
import org.hibernate.PessimisticLockException;
import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.dialect.aggregate.AggregateSupport;
@ -22,6 +23,10 @@ import org.hibernate.engine.jdbc.env.spi.IdentifierCaseStrategy;
import org.hibernate.engine.jdbc.env.spi.IdentifierHelper;
import org.hibernate.engine.jdbc.env.spi.IdentifierHelperBuilder;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.exception.ConstraintViolationException;
import org.hibernate.exception.LockAcquisitionException;
import org.hibernate.exception.LockTimeoutException;
import org.hibernate.exception.spi.SQLExceptionConversionDelegate;
import org.hibernate.query.sqm.CastType;
import org.hibernate.service.ServiceRegistry;
import org.hibernate.sql.ast.SqlAstTranslator;
@ -38,6 +43,7 @@ import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry;
import static org.hibernate.internal.util.JdbcExceptionHelper.extractSqlState;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
import static org.hibernate.type.SqlTypes.GEOMETRY;
import static org.hibernate.type.SqlTypes.OTHER;
@ -315,6 +321,46 @@ public class MariaDBDialect extends MySQLDialect {
return "dual";
}
/**
 * Builds the delegate that converts a {@link java.sql.SQLException} raised by MariaDB
 * into a Hibernate exception.
 * <p>
 * The vendor error code is inspected first. When it is not one of the handled codes,
 * the SQLState is inspected. The delegate returns {@code null} when neither matches.
 *
 * @return the conversion delegate for this dialect
 */
@Override
public SQLExceptionConversionDelegate buildSQLExceptionConversionDelegate() {
	return (sqlException, message, sql) -> {
		final int vendorCode = sqlException.getErrorCode();
		switch ( vendorCode ) {
			case 1020:
				// DB_RECORD_CHANGED: with @@innodb_snapshot_isolation enabled (the default since 11.6.2),
				// trying to lock a record that does not exist in the current read view raises this error.
			case 1206:
			case 1207:
				return new LockAcquisitionException( message, sqlException, sql );
			case 1205:
			case 3572:
				return new PessimisticLockException( message, sqlException, sql );
			case 1062:
				// Unique constraint violation
				final String constraintName =
						getViolatedConstraintNameExtractor().extractConstraintName( sqlException );
				return new ConstraintViolationException(
						message,
						sqlException,
						sql,
						ConstraintViolationException.ConstraintKind.UNIQUE,
						constraintName
				);
			default:
				break;
		}

		// No vendor code matched: fall back to the SQLState, if the exception carries one.
		final String state = extractSqlState( sqlException );
		if ( state == null ) {
			return null;
		}
		if ( "41000".equals( state ) ) {
			return new LockTimeoutException( message, sqlException, sql );
		}
		if ( "40001".equals( state ) ) {
			return new LockAcquisitionException( message, sqlException, sql );
		}
		return null;
	};
}
@Override
public boolean equivalentTypes(int typeCode1, int typeCode2) {
return typeCode1 == Types.LONGVARCHAR && typeCode2 == SqlTypes.JSON

View File

@ -682,8 +682,8 @@ public class SqlTypes {
/**
* A type code representing an {@code embedding vector} type for databases
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL},
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}.
* An embedding vector essentially is a {@code float[]} with a fixed size.
*
* @since 6.4

View File

@ -14,6 +14,7 @@ import jakarta.persistence.RollbackException;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
import org.junit.Test;
@ -69,7 +70,7 @@ public class BatchOptimisticLockingTest extends
} );
try {
inTransaction( (session) -> {
inTransaction( session -> {
List<Person> persons = session
.createSelectionQuery( "select p from Person p", Person.class )
.getResultList();
@ -107,10 +108,19 @@ public class BatchOptimisticLockingTest extends
}
else {
assertEquals( OptimisticLockException.class, expected.getClass() );
assertTrue(
expected.getMessage()
.startsWith("Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]")
);
if ( getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
assertTrue(
expected.getMessage()
.contains( "Record has changed since last read in table 'Person'" )
);
} else {
assertTrue(
expected.getMessage()
.startsWith(
"Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]" )
);
}
}
}
}

View File

@ -13,6 +13,8 @@ import jakarta.persistence.RollbackException;
import org.hibernate.StaleObjectStateException;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.exception.LockAcquisitionException;
import org.hibernate.exception.TransactionSerializationException;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.DomainModel;
@ -32,6 +34,7 @@ import jakarta.persistence.Table;
import jakarta.persistence.Version;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.hibernate.testing.orm.junit.DialectContext.getDialect;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@ -136,7 +139,12 @@ public class BatchUpdateAndVersionTest {
fail();
}
catch (OptimisticLockException ole) {
assertTrue( ole.getCause() instanceof StaleObjectStateException );
if (getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
// if @@innodb_snapshot_isolation is set, the database throws an exception when the record is no longer available
assertTrue( ole.getCause() instanceof LockAcquisitionException );
} else {
assertTrue( ole.getCause() instanceof StaleObjectStateException );
}
}
//CockroachDB errors with a Serialization Exception
catch (RollbackException rbe) {

View File

@ -12,6 +12,7 @@ import org.hibernate.StaleObjectStateException;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.dialect.SQLServerDialect;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.exception.SQLGrammarException;
@ -105,6 +106,7 @@ public class RepeatableReadTest extends AbstractJPATest {
@Test
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row")
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
public void testStaleVersionedInstanceFoundOnLock() {
if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) {
return;
@ -228,6 +230,7 @@ public class RepeatableReadTest extends AbstractJPATest {
@Test
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row")
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
public void testStaleNonVersionedInstanceFoundOnLock() {
if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) {
return;

View File

@ -10,7 +10,7 @@ import java.util.stream.Stream;
import org.hibernate.LockMode;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.SessionFactory;
@ -30,6 +30,7 @@ import jakarta.persistence.Version;
@SessionFactory
@JiraKey("HHH-16461")
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB uses SERIALIZABLE isolation, and does not support this")
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
public class OptimisticAndPessimisticLockTest {
public Stream<LockMode> pessimisticLockModes() {

View File

@ -10,8 +10,9 @@ import jakarta.persistence.Id;
import jakarta.persistence.Version;
import org.hibernate.annotations.OptimisticLock;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.orm.test.jpa.BaseEntityManagerFunctionalTestCase;
import org.hibernate.testing.SkipForDialect;
import org.hibernate.testing.orm.junit.SkipForDialect;
import org.junit.Test;
import static org.hibernate.testing.transaction.TransactionUtil.doInJPA;
@ -29,7 +30,8 @@ public class OptimisticLockTest extends BaseEntityManagerFunctionalTestCase {
}
@Test
@SkipForDialect(value = CockroachDialect.class, comment = "Fails at SERIALIZABLE isolation")
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Fails at SERIALIZABLE isolation")
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
public void test() {
doInJPA(this::entityManagerFactory, entityManager -> {
Phone phone = new Phone();

View File

@ -13,6 +13,7 @@ import org.hibernate.StaleObjectStateException;
import org.hibernate.StaleStateException;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.dialect.SQLServerDialect;
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
@ -23,6 +24,7 @@ import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import static org.hibernate.testing.orm.junit.DialectContext.getDialect;
import static org.junit.jupiter.api.Assertions.fail;
/**
@ -189,8 +191,9 @@ public class OptimisticLockTest {
"40001" ) ) {
// CockroachDB always runs in SERIALIZABLE isolation, and uses SQL state 40001 to indicate
// serialization failure.
}
else {
} else if (dialect instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
// MariaDB with snapshot isolation raises an error here instead; that is an accepted outcome
} else {
throw e;
}
}

View File

@ -0,0 +1,91 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.vector;
import org.hibernate.dialect.Dialect;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.ValueBinder;
import org.hibernate.type.descriptor.ValueExtractor;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.BasicBinder;
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.spi.TypeConfiguration;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
/**
 * An {@link ArrayJdbcType} that reports {@link SqlTypes#VECTOR} as its default SQL type code
 * and exchanges values with the JDBC driver as {@code float[]}: reads go through
 * {@code getObject(..., float[].class)} and writes go through {@code setObject(...)}.
 */
public class BinaryVectorJdbcType extends ArrayJdbcType {

	/** Java class requested from the driver when a vector is read. */
	private static final Class<float[]> DRIVER_VALUE_CLASS = float[].class;

	/**
	 * @param elementJdbcType the JDBC type of the vector elements, handed to {@link ArrayJdbcType}
	 */
	public BinaryVectorJdbcType(JdbcType elementJdbcType) {
		super( elementJdbcType );
	}

	@Override
	public int getDefaultSqlTypeCode() {
		return SqlTypes.VECTOR;
	}

	/**
	 * Recommends the {@code float[]} Java type, regardless of precision and scale.
	 */
	@Override
	public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
			Integer precision,
			Integer scale,
			TypeConfiguration typeConfiguration) {
		return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( DRIVER_VALUE_CLASS );
	}

	/**
	 * Appends the write expression exactly as given, without any wrapping.
	 */
	@Override
	public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
		appender.append( writeExpression );
	}

	/**
	 * Creates an extractor that fetches the column or parameter as {@code float[]} and wraps it
	 * into the requested Java type.
	 */
	@Override
	public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
		return new BasicExtractor<>( javaTypeDescriptor, this ) {
			@Override
			protected X doExtract(ResultSet resultSet, int position, WrapperOptions wrapperOptions)
					throws SQLException {
				final float[] vector = resultSet.getObject( position, DRIVER_VALUE_CLASS );
				return javaTypeDescriptor.wrap( vector, wrapperOptions );
			}

			@Override
			protected X doExtract(CallableStatement call, int position, WrapperOptions wrapperOptions)
					throws SQLException {
				final float[] vector = call.getObject( position, DRIVER_VALUE_CLASS );
				return javaTypeDescriptor.wrap( vector, wrapperOptions );
			}

			@Override
			protected X doExtract(CallableStatement call, String parameterName, WrapperOptions wrapperOptions)
					throws SQLException {
				final float[] vector = call.getObject( parameterName, DRIVER_VALUE_CLASS );
				return javaTypeDescriptor.wrap( vector, wrapperOptions );
			}
		};
	}

	/**
	 * Creates a binder that hands the value to the driver unchanged.
	 */
	@Override
	public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
		return new BasicBinder<>( javaTypeDescriptor, this ) {
			@Override
			protected void doBind(PreparedStatement statement, X vector, int position, WrapperOptions wrapperOptions)
					throws SQLException {
				statement.setObject( position, vector );
			}

			@Override
			protected void doBind(CallableStatement call, X vector, String parameterName, WrapperOptions wrapperOptions)
					throws SQLException {
				// The named-parameter variant passes an explicit JDBC type code.
				call.setObject( parameterName, vector, java.sql.Types.ARRAY );
			}

			@Override
			public Object getBindValue(X vector, WrapperOptions wrapperOptions) {
				return vector;
			}
		};
	}
}

View File

@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.vector;
import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * Contributes the vector distance functions {@code cosine_distance} and {@code euclidean_distance}
 * (also reachable as {@code l2_distance}) when the active dialect is a {@link MariaDBDialect}.
 */
public class MariaDBFunctionContributor implements FunctionContributor {

	/**
	 * Registers the distance functions; does nothing for any dialect other than {@link MariaDBDialect}.
	 *
	 * @param functionContributions access to the function registry, type configuration and dialect
	 */
	@Override
	public void contributeFunctions(FunctionContributions functionContributions) {
		final SqmFunctionRegistry registry = functionContributions.getFunctionRegistry();
		final TypeConfiguration configuration = functionContributions.getTypeConfiguration();
		final BasicTypeRegistry basicTypes = configuration.getBasicTypeRegistry();
		final Dialect activeDialect = functionContributions.getDialect();
		if ( !( activeDialect instanceof MariaDBDialect ) ) {
			return;
		}

		final BasicType<Double> distanceType = basicTypes.resolve( StandardBasicTypes.DOUBLE );
		registerDistanceFunction( registry, "cosine_distance", "vec_distance_cosine(?1,?2)", distanceType );
		registerDistanceFunction( registry, "euclidean_distance", "vec_distance_euclidean(?1,?2)", distanceType );
		registry.registerAlternateKey( "l2_distance", "euclidean_distance" );
	}

	/**
	 * Registers one two-argument vector function that is rendered through the given SQL pattern
	 * and always returns the given type.
	 */
	private static void registerDistanceFunction(
			SqmFunctionRegistry registry,
			String functionName,
			String sqlPattern,
			BasicType<Double> returnType) {
		registry.patternDescriptorBuilder( functionName, sqlPattern )
				.setArgumentsValidator( StandardArgumentsValidators.composite(
						StandardArgumentsValidators.exactly( 2 ),
						VectorArgumentValidator.INSTANCE
				) )
				.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
				.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
				.register();
	}

	@Override
	public int ordinal() {
		return 200;
	}
}

View File

@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.vector;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.boot.model.TypeContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.service.ServiceRegistry;
import org.hibernate.type.BasicArrayType;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
import org.hibernate.type.spi.TypeConfiguration;
import java.lang.reflect.Type;
/**
 * Contributes the vector type mappings when the active dialect is a {@link MariaDBDialect}:
 * a {@link BinaryVectorJdbcType} for {@link SqlTypes#VECTOR}, basic types for {@code Float[]} and
 * {@code float[]} under the {@link StandardBasicTypes#VECTOR} name, and the {@code vector} DDL type.
 */
public class MariaDBTypeContributor implements TypeContributor {

	/** Java array types that get a vector basic type registered. */
	private static final Type[] VECTOR_JAVA_TYPES = { Float[].class, float[].class };

	/**
	 * Registers the vector types; does nothing for any dialect other than {@link MariaDBDialect}.
	 *
	 * @param typeContributions access to the type configuration to contribute to
	 * @param serviceRegistry used to look up the active dialect through {@link JdbcServices}
	 */
	@Override
	public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
		final Dialect activeDialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
		if ( !( activeDialect instanceof MariaDBDialect ) ) {
			return;
		}

		final TypeConfiguration configuration = typeContributions.getTypeConfiguration();
		final JavaTypeRegistry javaTypes = configuration.getJavaTypeRegistry();
		final JdbcTypeRegistry jdbcTypes = configuration.getJdbcTypeRegistry();
		final BasicTypeRegistry basicTypes = configuration.getBasicTypeRegistry();

		final BasicType<Float> elementType = basicTypes.resolve( StandardBasicTypes.FLOAT );
		final ArrayJdbcType vectorJdbcType =
				new BinaryVectorJdbcType( jdbcTypes.getDescriptor( SqlTypes.FLOAT ) );
		jdbcTypes.addDescriptor( SqlTypes.VECTOR, vectorJdbcType );

		for ( Type javaArrayType : VECTOR_JAVA_TYPES ) {
			basicTypes.register(
					new BasicArrayType<>(
							elementType,
							vectorJdbcType,
							javaTypes.getDescriptor( javaArrayType )
					),
					StandardBasicTypes.VECTOR.getName()
			);
		}

		// DDL type whose length placeholder is filled from the array length of the column size.
		final DdlTypeImpl vectorDdlType = new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", activeDialect ) {
			@Override
			public String getTypeName(Size size) {
				final Number arrayLength = size.getArrayLength();
				final Long length = arrayLength == null ? null : arrayLength.longValue();
				return getTypeName( length, null, null );
			}
		};
		configuration.getDdlTypeRegistry().addDescriptor( vectorDdlType );
	}
}

View File

@ -1,2 +1,3 @@
org.hibernate.vector.PGVectorFunctionContributor
org.hibernate.vector.OracleVectorFunctionContributor
org.hibernate.vector.MariaDBFunctionContributor

View File

@ -1,2 +1,3 @@
org.hibernate.vector.PGVectorTypeContributor
org.hibernate.vector.OracleVectorTypeContributor
org.hibernate.vector.MariaDBTypeContributor

View File

@ -0,0 +1,167 @@
/*
* SPDX-License-Identifier: LGPL-2.1-or-later
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.vector;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.MariaDBDialect;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.hibernate.type.SqlTypes;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests reading vector values and the {@code cosine_distance} / {@code euclidean_distance}
 * HQL functions against MariaDB.
 * <p>
 * The test is restricted through {@code @RequiresDialect} to {@link MariaDBDialect}
 * (sub-types excluded) with major version 11 and minor version 7.
 * <p>
 * NOTE(review): the {@code //tag::...[]} and {@code //end::...[]} comments look like
 * documentation include markers; keep them in place when editing.
 *
 * @author Diego Dupin
 */
@DomainModel(annotatedClasses = MariaDBTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = MariaDBDialect.class, matchSubTypes = false, majorVersion = 11, minorVersion = 7)
public class MariaDBTest {
// Vectors stored for the entities with id 1 and id 2 respectively.
private static final float[] V1 = new float[]{ 1, 2, 3 };
private static final float[] V2 = new float[]{ 4, 5, 6 };
/**
 * Persists the two fixture entities before each test.
 */
@BeforeEach
public void prepareData(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.persist( new VectorEntity( 1L, V1 ) );
em.persist( new VectorEntity( 2L, V2 ) );
} );
}
/**
 * Deletes every {@code VectorEntity} after each test.
 */
@AfterEach
public void cleanup(SessionFactoryScope scope) {
scope.inTransaction( em -> {
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
} );
}
/**
 * Loads both entities by id and checks that the stored vectors come back unchanged
 * (compared with a delta of 0).
 */
@Test
public void testRead(SessionFactoryScope scope) {
scope.inTransaction( em -> {
VectorEntity tableRecord;
tableRecord = em.find( VectorEntity.class, 1L );
assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );
tableRecord = em.find( VectorEntity.class, 2L );
assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
} );
}
/**
 * Queries {@code cosine_distance} between each stored vector and {@code (1, 1, 1)} and compares
 * the result with the value computed locally by {@link #cosineDistance(float[], float[])}.
 */
@Test
public void testCosineDistance(SessionFactoryScope scope) {
scope.inTransaction( em -> {
//tag::cosine-distance-example[]
final float[] vector = new float[]{ 1, 1, 1 };
final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
.setParameter( "vec", vector )
.getResultList();
//end::cosine-distance-example[]
// Rows are ordered by id, so index 0 is V1 and index 1 is V2.
assertEquals( 2, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D );
} );
}
/**
 * Queries {@code euclidean_distance} between each stored vector and {@code (1, 1, 1)} and compares
 * the result with the value computed locally by {@link #euclideanDistance(float[], float[])},
 * using a delta of 0.
 */
@Test
public void testEuclideanDistance(SessionFactoryScope scope) {
scope.inTransaction( em -> {
//tag::euclidean-distance-example[]
final float[] vector = new float[]{ 1, 1, 1 };
final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
.setParameter( "vec", vector )
.getResultList();
//end::euclidean-distance-example[]
// Rows are ordered by id, so index 0 is V1 and index 1 is V2.
assertEquals( 2, results.size() );
assertEquals( 1L, results.get( 0 ).get( 0 ) );
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
assertEquals( 2L, results.get( 1 ).get( 0 ) );
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
} );
}
/**
 * Returns one minus the inner product of the two vectors divided by the product of their
 * Euclidean norms.
 */
private static double cosineDistance(float[] f1, float[] f2) {
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
}
/**
 * Returns the square root of the sum of squared element-wise differences, accumulated in
 * {@code double}. Both arrays are asserted to have the same length.
 */
private static double euclideanDistance(float[] f1, float[] f2) {
assert f1.length == f2.length;
double result = 0;
for ( int i = 0; i < f1.length; i++ ) {
result += Math.pow( (double) f1[i] - f2[i], 2 );
}
return Math.sqrt( result );
}
/**
 * Returns the sum of element-wise products, accumulated in {@code double}.
 * Both arrays are asserted to have the same length.
 */
private static double innerProduct(float[] f1, float[] f2) {
assert f1.length == f2.length;
double result = 0;
for ( int i = 0; i < f1.length; i++ ) {
result += ( (double) f1[i] ) * ( (double) f2[i] );
}
return result;
}
/**
 * Returns the square root of the sum of squared elements.
 */
private static double euclideanNorm(float[] f) {
double result = 0;
for ( float v : f ) {
result += Math.pow( v, 2 );
}
return Math.sqrt( result );
}
/**
 * Entity with a {@code Long} id and a three-element {@code float[]} mapped with
 * {@link SqlTypes#VECTOR} to the column {@code the_vector}.
 */
@Entity( name = "VectorEntity" )
public static class VectorEntity {
@Id
private Long id;
//tag::usage-example[]
@Column( name = "the_vector" )
@JdbcTypeCode(SqlTypes.VECTOR)
@Array(length = 3)
private float[] theVector;
//end::usage-example[]
// No-argument constructor.
public VectorEntity() {
}
public VectorEntity(Long id, float[] theVector) {
this.id = id;
this.theVector = theVector;
}
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public float[] getTheVector() {
return theVector;
}
public void setTheVector(float[] theVector) {
this.theVector = theVector;
}
}
}