From 65a26a214fbcf00c9ef99e7c7c2a90eddffc41ab Mon Sep 17 00:00:00 2001 From: diego Date: Thu, 19 Dec 2024 16:24:40 +0000 Subject: [PATCH] HHH-18900 MariaDB Vector support + adding support and test correction for mariadb 11.6.2 snapshot isolation --- databases/mariadb/matrix.gradle | 2 +- docker_db.sh | 9 +- .../org/hibernate/dialect/MariaDBDialect.java | 46 +++++ .../java/org/hibernate/type/SqlTypes.java | 4 +- .../batch/BatchOptimisticLockingTest.java | 20 ++- .../test/batch/BatchUpdateAndVersionTest.java | 10 +- .../orm/test/jpa/lock/RepeatableReadTest.java | 3 + .../OptimisticAndPessimisticLockTest.java | 3 +- .../orm/test/locking/OptimisticLockTest.java | 6 +- .../orm/test/optlock/OptimisticLockTest.java | 7 +- .../vector/BinaryVectorJdbcType.java | 91 ++++++++++ .../vector/MariaDBFunctionContributor.java | 55 ++++++ .../vector/MariaDBTypeContributor.java | 69 ++++++++ ...g.hibernate.boot.model.FunctionContributor | 3 +- .../org.hibernate.boot.model.TypeContributor | 3 +- .../org/hibernate/vector/MariaDBTest.java | 167 ++++++++++++++++++ 16 files changed, 481 insertions(+), 17 deletions(-) create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java create mode 100644 hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java diff --git a/databases/mariadb/matrix.gradle b/databases/mariadb/matrix.gradle index f66f5cacb3..b72fdee035 100644 --- a/databases/mariadb/matrix.gradle +++ b/databases/mariadb/matrix.gradle @@ -4,4 +4,4 @@ * License: GNU Lesser General Public License (LGPL), version 2.1 or later. * See the lgpl.txt file in the root directory or . */ -jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.4.0' +jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.5.1' diff --git a/docker_db.sh b/docker_db.sh index 85538cd6a0..e6a78a7b2b 100755 --- a/docker_db.sh +++ b/docker_db.sh @@ -92,7 +92,7 @@ mysql_8_2() { } mariadb() { - mariadb_11_4 + mariadb_11_7 } mariadb_wait_until_start() @@ -138,6 +138,12 @@ mariadb_11_4() { mariadb_wait_until_start } +mariadb_11_7() { + $CONTAINER_CLI rm -f mariadb || true + $CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_11_7:-docker.io/mariadb:11.7-rc} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2 + mariadb_wait_until_start +} + mariadb_verylatest() { $CONTAINER_CLI rm -f mariadb || true $CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_VERYLATEST:-quay.io/mariadb-foundation/mariadb-devel:verylatest} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2 @@ -996,6 +1002,7 @@ if [ -z ${1} ]; then echo -e "\thana" echo -e "\tmariadb" echo -e "\tmariadb_verylatest" + echo -e "\tmariadb_11_7" echo -e "\tmariadb_11_4" echo -e "\tmariadb_11_1" echo -e "\tmariadb_10_11" diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java index d712767fc7..2d3736e398 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/MariaDBDialect.java @@ -8,6 +8,7 @@ import java.sql.DatabaseMetaData; import java.sql.SQLException; import java.sql.Types; +import org.hibernate.PessimisticLockException; import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.TypeContributions; import org.hibernate.dialect.aggregate.AggregateSupport; @@ -22,6 +23,10 @@ import org.hibernate.engine.jdbc.env.spi.IdentifierCaseStrategy; import org.hibernate.engine.jdbc.env.spi.IdentifierHelper; import org.hibernate.engine.jdbc.env.spi.IdentifierHelperBuilder; import org.hibernate.engine.spi.SessionFactoryImplementor; +import org.hibernate.exception.ConstraintViolationException; +import org.hibernate.exception.LockAcquisitionException; +import org.hibernate.exception.LockTimeoutException; +import org.hibernate.exception.spi.SQLExceptionConversionDelegate; import org.hibernate.query.sqm.CastType; import org.hibernate.service.ServiceRegistry; import org.hibernate.sql.ast.SqlAstTranslator; @@ -38,6 +43,7 @@ import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry; +import static org.hibernate.internal.util.JdbcExceptionHelper.extractSqlState; import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC; import static org.hibernate.type.SqlTypes.GEOMETRY; import static org.hibernate.type.SqlTypes.OTHER; @@ -315,6 +321,46 @@ public class MariaDBDialect extends MySQLDialect { return "dual"; } + @Override + public SQLExceptionConversionDelegate buildSQLExceptionConversionDelegate() { + return (sqlException, message, sql) -> { + switch ( sqlException.getErrorCode() ) { + // If @@innodb_snapshot_isolation is set (default since 11.6.2), + // if an attempt to acquire a lock on a record that does not exist in the current read view is made, + // an error DB_RECORD_CHANGED will be raised. + case 1020: + return new LockAcquisitionException( message, sqlException, sql ); + case 1205: + case 3572: + return new PessimisticLockException( message, sqlException, sql ); + case 1207: + case 1206: + return new LockAcquisitionException( message, sqlException, sql ); + case 1062: + // Unique constraint violation + return new ConstraintViolationException( + message, + sqlException, + sql, + ConstraintViolationException.ConstraintKind.UNIQUE, + getViolatedConstraintNameExtractor().extractConstraintName( sqlException ) + ); + } + + final String sqlState = extractSqlState( sqlException ); + if ( sqlState != null ) { + switch ( sqlState ) { + case "41000": + return new LockTimeoutException( message, sqlException, sql ); + case "40001": + return new LockAcquisitionException( message, sqlException, sql ); + } + } + + return null; + }; + } + @Override public boolean equivalentTypes(int typeCode1, int typeCode2) { return typeCode1 == Types.LONGVARCHAR && typeCode2 == SqlTypes.JSON diff --git a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java index d84e9061d6..772ecea406 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java @@ -682,8 +682,8 @@ public class SqlTypes { /** * A type code representing an {@code embedding vector} type for databases - * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and - * {@link org.hibernate.dialect.OracleDialect Oracle 23ai}. + * like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL}, + * {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}. * An embedding vector essentially is a {@code float[]} with a fixed size. * * @since 6.4 diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchOptimisticLockingTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchOptimisticLockingTest.java index f299c62442..1383a64cb7 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchOptimisticLockingTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchOptimisticLockingTest.java @@ -14,6 +14,7 @@ import jakarta.persistence.RollbackException; import org.hibernate.cfg.AvailableSettings; import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase; import org.junit.Test; @@ -69,7 +70,7 @@ public class BatchOptimisticLockingTest extends } ); try { - inTransaction( (session) -> { + inTransaction( session -> { List persons = session .createSelectionQuery( "select p from Person p", Person.class ) .getResultList(); @@ -107,10 +108,19 @@ public class BatchOptimisticLockingTest extends } else { assertEquals( OptimisticLockException.class, expected.getClass() ); - assertTrue( - expected.getMessage() - .startsWith("Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]") - ); + + if ( getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) { + assertTrue( + expected.getMessage() + .contains( "Record has changed since last read in table 'Person'" ) + ); + } else { + assertTrue( + expected.getMessage() + .startsWith( + "Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]" ) + ); + } } } } diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchUpdateAndVersionTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchUpdateAndVersionTest.java index 0a0a59282c..a6f96325f2 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchUpdateAndVersionTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/batch/BatchUpdateAndVersionTest.java @@ -13,6 +13,8 @@ import jakarta.persistence.RollbackException; import org.hibernate.StaleObjectStateException; import org.hibernate.cfg.AvailableSettings; +import org.hibernate.dialect.MariaDBDialect; +import org.hibernate.exception.LockAcquisitionException; import org.hibernate.exception.TransactionSerializationException; import org.hibernate.testing.orm.junit.JiraKey; import org.hibernate.testing.orm.junit.DomainModel; @@ -32,6 +34,7 @@ import jakarta.persistence.Table; import jakarta.persistence.Version; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.hibernate.testing.orm.junit.DialectContext.getDialect; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -136,7 +139,12 @@ public class BatchUpdateAndVersionTest { fail(); } catch (OptimisticLockException ole) { - assertTrue( ole.getCause() instanceof StaleObjectStateException ); + if (getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) { + // if @@innodb_snapshot_isolation is set, database throw an exception if record is not available anymore + assertTrue( ole.getCause() instanceof LockAcquisitionException ); + } else { + assertTrue( ole.getCause() instanceof StaleObjectStateException ); + } } //CockroachDB errors with a Serialization Exception catch (RollbackException rbe) { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/lock/RepeatableReadTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/lock/RepeatableReadTest.java index ecc80e3936..3e58315a57 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/lock/RepeatableReadTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/jpa/lock/RepeatableReadTest.java @@ -12,6 +12,7 @@ import org.hibernate.StaleObjectStateException; import org.hibernate.boot.registry.StandardServiceRegistryBuilder; import org.hibernate.cfg.AvailableSettings; import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider; import org.hibernate.exception.SQLGrammarException; @@ -105,6 +106,7 @@ public class RepeatableReadTest extends AbstractJPATest { @Test @SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row") + @SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed") public void testStaleVersionedInstanceFoundOnLock() { if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) { return; @@ -228,6 +230,7 @@ public class RepeatableReadTest extends AbstractJPATest { @Test @SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row") + @SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed") public void testStaleNonVersionedInstanceFoundOnLock() { if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) { return; diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticAndPessimisticLockTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticAndPessimisticLockTest.java index 97289f363a..d65c75c3b8 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticAndPessimisticLockTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticAndPessimisticLockTest.java @@ -10,7 +10,7 @@ import java.util.stream.Stream; import org.hibernate.LockMode; import org.hibernate.dialect.CockroachDialect; - +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.testing.orm.junit.DomainModel; import org.hibernate.testing.orm.junit.JiraKey; import org.hibernate.testing.orm.junit.SessionFactory; @@ -30,6 +30,7 @@ import jakarta.persistence.Version; @SessionFactory @JiraKey("HHH-16461") @SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB uses SERIALIZABLE isolation, and does not support this") +@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed") public class OptimisticAndPessimisticLockTest { public Stream pessimisticLockModes() { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticLockTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticLockTest.java index 55b7107116..fb35a0f8db 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticLockTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/locking/OptimisticLockTest.java @@ -10,8 +10,9 @@ import jakarta.persistence.Id; import jakarta.persistence.Version; import org.hibernate.annotations.OptimisticLock; import org.hibernate.dialect.CockroachDialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.orm.test.jpa.BaseEntityManagerFunctionalTestCase; -import org.hibernate.testing.SkipForDialect; +import org.hibernate.testing.orm.junit.SkipForDialect; import org.junit.Test; import static org.hibernate.testing.transaction.TransactionUtil.doInJPA; @@ -29,7 +30,8 @@ public class OptimisticLockTest extends BaseEntityManagerFunctionalTestCase { } @Test - @SkipForDialect(value = CockroachDialect.class, comment = "Fails at SERIALIZABLE isolation") + @SkipForDialect(dialectClass = CockroachDialect.class, reason = "Fails at SERIALIZABLE isolation") + @SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed") public void test() { doInJPA(this::entityManagerFactory, entityManager -> { Phone phone = new Phone(); diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/optlock/OptimisticLockTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/optlock/OptimisticLockTest.java index 3288dee2b5..127db7ac10 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/optlock/OptimisticLockTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/optlock/OptimisticLockTest.java @@ -13,6 +13,7 @@ import org.hibernate.StaleObjectStateException; import org.hibernate.StaleStateException; import org.hibernate.dialect.CockroachDialect; import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MariaDBDialect; import org.hibernate.dialect.SQLServerDialect; import org.hibernate.testing.orm.junit.DialectFeatureChecks; @@ -23,6 +24,7 @@ import org.hibernate.testing.orm.junit.SessionFactoryScope; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import static org.hibernate.testing.orm.junit.DialectContext.getDialect; import static org.junit.jupiter.api.Assertions.fail; /** @@ -189,8 +191,9 @@ public class OptimisticLockTest { "40001" ) ) { // CockroachDB always runs in SERIALIZABLE isolation, and uses SQL state 40001 to indicate // serialization failure. - } - else { + } else if (dialect instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) { + // Mariadb snapshot_isolation throws error + } else { throw e; } } diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java new file mode 100644 index 0000000000..a9d961e62c --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/BinaryVectorJdbcType.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.spi.TypeConfiguration; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +public class BinaryVectorJdbcType extends ArrayJdbcType { + + public BinaryVectorJdbcType(JdbcType elementJdbcType) { + super( elementJdbcType ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR; + } + + @Override + public JavaType getJdbcRecommendedJavaTypeMapping( + Integer precision, + Integer scale, + TypeConfiguration typeConfiguration) { + return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class ); + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( writeExpression ); + } + + @Override + public ValueExtractor getExtractor(JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( rs.getObject( paramIndex, float[].class ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( statement.getObject( index, float[].class ), options ); + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException { + return javaTypeDescriptor.wrap( statement.getObject( name, float[].class ), options ); + } + + }; + } + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException { + st.setObject( index, value ); + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + st.setObject( name, value, java.sql.Types.ARRAY ); + } + + @Override + public Object getBindValue(X value, WrapperOptions options) { + return value; + } + }; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java new file mode 100644 index 0000000000..670238c6be --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java @@ -0,0 +1,55 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MariaDBDialect; +import org.hibernate.query.sqm.function.SqmFunctionRegistry; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.spi.TypeConfiguration; + +public class MariaDBFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry(); + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof MariaDBDialect ) { + final BasicType doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ); + + functionRegistry.patternDescriptorBuilder( "cosine_distance", "vec_distance_cosine(?1,?2)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vec_distance_euclidean(?1,?2)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" ); + + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java new file mode 100644 index 0000000000..08452406f6 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/MariaDBTypeContributor.java @@ -0,0 +1,69 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.MariaDBDialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; +import org.hibernate.type.spi.TypeConfiguration; + +import java.lang.reflect.Type; + +public class MariaDBTypeContributor implements TypeContributor { + + private static final Type[] VECTOR_JAVA_TYPES = { + Float[].class, + float[].class + }; + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect(); + if ( dialect instanceof MariaDBDialect ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final BasicType floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ); + final ArrayJdbcType vectorJdbcType = new BinaryVectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType ); + for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) { + basicTypeRegistry.register( + new BasicArrayType<>( + floatBasicType, + vectorJdbcType, + javaTypeRegistry.getDescriptor( vectorJavaType ) + ), + StandardBasicTypes.VECTOR.getName() + ); + } + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) { + @Override + public String getTypeName(Size size) { + return getTypeName( + size.getArrayLength() == null ? null : size.getArrayLength().longValue(), + null, + null + ); + } + } + ); + } + } +} diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor index 477fddb117..6103956ccb 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor @@ -1,2 +1,3 @@ org.hibernate.vector.PGVectorFunctionContributor -org.hibernate.vector.OracleVectorFunctionContributor \ No newline at end of file +org.hibernate.vector.OracleVectorFunctionContributor +org.hibernate.vector.MariaDBFunctionContributor diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor index eeaa217a75..11605464c8 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor @@ -1,2 +1,3 @@ org.hibernate.vector.PGVectorTypeContributor -org.hibernate.vector.OracleVectorTypeContributor \ No newline at end of file +org.hibernate.vector.OracleVectorTypeContributor +org.hibernate.vector.MariaDBTypeContributor diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java new file mode 100644 index 0000000000..814ae22962 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/MariaDBTest.java @@ -0,0 +1,167 @@ +/* + * SPDX-License-Identifier: LGPL-2.1-or-later + * Copyright Red Hat Inc. and Hibernate Authors + */ +package org.hibernate.vector; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.MariaDBDialect; +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.hibernate.type.SqlTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author Diego Dupin + */ +@DomainModel(annotatedClasses = MariaDBTest.VectorEntity.class) +@SessionFactory +@RequiresDialect(value = MariaDBDialect.class, matchSubTypes = false, majorVersion = 11, minorVersion = 7) +public class MariaDBTest { + + private static final float[] V1 = new float[]{ 1, 2, 3 }; + private static final float[] V2 = new float[]{ 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Test + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::cosine-distance-example[] + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::cosine-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D ); + } ); + } + + @Test + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-distance-example[] + final float[] vector = new float[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D ); + } ); + } + + private static double cosineDistance(float[] f1, float[] f2) { + return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); + } + + private static double euclideanDistance(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += Math.pow( (double) f1[i] - f2[i], 2 ); + } + return Math.sqrt( result ); + } + + private static double innerProduct(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += ( (double) f1[i] ) * ( (double) f2[i] ); + } + return result; + } + + private static double euclideanNorm(float[] f) { + double result = 0; + for ( float v : f ) { + result += Math.pow( v, 2 ); + } + return Math.sqrt( result ); + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + //tag::usage-example[] + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.VECTOR) + @Array(length = 3) + private float[] theVector; + //end::usage-example[] + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +}