HHH-18900 MariaDB Vector support
+ adding support and test correction for mariadb 11.6.2 snapshot isolation
This commit is contained in:
parent
d22aeb1a52
commit
65a26a214f
|
@ -4,4 +4,4 @@
|
|||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.4.0'
|
||||
jdbcDependency 'org.mariadb.jdbc:mariadb-java-client:3.5.1'
|
||||
|
|
|
@ -92,7 +92,7 @@ mysql_8_2() {
|
|||
}
|
||||
|
||||
mariadb() {
|
||||
mariadb_11_4
|
||||
mariadb_11_7
|
||||
}
|
||||
|
||||
mariadb_wait_until_start()
|
||||
|
@ -138,6 +138,12 @@ mariadb_11_4() {
|
|||
mariadb_wait_until_start
|
||||
}
|
||||
|
||||
mariadb_11_7() {
|
||||
$CONTAINER_CLI rm -f mariadb || true
|
||||
$CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_11_7:-docker.io/mariadb:11.7-rc} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2
|
||||
mariadb_wait_until_start
|
||||
}
|
||||
|
||||
mariadb_verylatest() {
|
||||
$CONTAINER_CLI rm -f mariadb || true
|
||||
$CONTAINER_CLI run --name mariadb -e MARIADB_USER=hibernate_orm_test -e MARIADB_PASSWORD=hibernate_orm_test -e MARIADB_DATABASE=hibernate_orm_test -e MARIADB_ROOT_PASSWORD=hibernate_orm_test -p3306:3306 -d ${DB_IMAGE_MARIADB_VERYLATEST:-quay.io/mariadb-foundation/mariadb-devel:verylatest} --character-set-server=utf8mb4 --collation-server=utf8mb4_bin --skip-character-set-client-handshake --lower_case_table_names=2
|
||||
|
@ -996,6 +1002,7 @@ if [ -z ${1} ]; then
|
|||
echo -e "\thana"
|
||||
echo -e "\tmariadb"
|
||||
echo -e "\tmariadb_verylatest"
|
||||
echo -e "\tmariadb_11_7"
|
||||
echo -e "\tmariadb_11_4"
|
||||
echo -e "\tmariadb_11_1"
|
||||
echo -e "\tmariadb_10_11"
|
||||
|
|
|
@ -8,6 +8,7 @@ import java.sql.DatabaseMetaData;
|
|||
import java.sql.SQLException;
|
||||
import java.sql.Types;
|
||||
|
||||
import org.hibernate.PessimisticLockException;
|
||||
import org.hibernate.boot.model.FunctionContributions;
|
||||
import org.hibernate.boot.model.TypeContributions;
|
||||
import org.hibernate.dialect.aggregate.AggregateSupport;
|
||||
|
@ -22,6 +23,10 @@ import org.hibernate.engine.jdbc.env.spi.IdentifierCaseStrategy;
|
|||
import org.hibernate.engine.jdbc.env.spi.IdentifierHelper;
|
||||
import org.hibernate.engine.jdbc.env.spi.IdentifierHelperBuilder;
|
||||
import org.hibernate.engine.spi.SessionFactoryImplementor;
|
||||
import org.hibernate.exception.ConstraintViolationException;
|
||||
import org.hibernate.exception.LockAcquisitionException;
|
||||
import org.hibernate.exception.LockTimeoutException;
|
||||
import org.hibernate.exception.spi.SQLExceptionConversionDelegate;
|
||||
import org.hibernate.query.sqm.CastType;
|
||||
import org.hibernate.service.ServiceRegistry;
|
||||
import org.hibernate.sql.ast.SqlAstTranslator;
|
||||
|
@ -38,6 +43,7 @@ import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
|
|||
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
|
||||
import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry;
|
||||
|
||||
import static org.hibernate.internal.util.JdbcExceptionHelper.extractSqlState;
|
||||
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
|
||||
import static org.hibernate.type.SqlTypes.GEOMETRY;
|
||||
import static org.hibernate.type.SqlTypes.OTHER;
|
||||
|
@ -315,6 +321,46 @@ public class MariaDBDialect extends MySQLDialect {
|
|||
return "dual";
|
||||
}
|
||||
|
||||
@Override
|
||||
public SQLExceptionConversionDelegate buildSQLExceptionConversionDelegate() {
|
||||
return (sqlException, message, sql) -> {
|
||||
switch ( sqlException.getErrorCode() ) {
|
||||
// If @@innodb_snapshot_isolation is set (default since 11.6.2),
|
||||
// if an attempt to acquire a lock on a record that does not exist in the current read view is made,
|
||||
// an error DB_RECORD_CHANGED will be raised.
|
||||
case 1020:
|
||||
return new LockAcquisitionException( message, sqlException, sql );
|
||||
case 1205:
|
||||
case 3572:
|
||||
return new PessimisticLockException( message, sqlException, sql );
|
||||
case 1207:
|
||||
case 1206:
|
||||
return new LockAcquisitionException( message, sqlException, sql );
|
||||
case 1062:
|
||||
// Unique constraint violation
|
||||
return new ConstraintViolationException(
|
||||
message,
|
||||
sqlException,
|
||||
sql,
|
||||
ConstraintViolationException.ConstraintKind.UNIQUE,
|
||||
getViolatedConstraintNameExtractor().extractConstraintName( sqlException )
|
||||
);
|
||||
}
|
||||
|
||||
final String sqlState = extractSqlState( sqlException );
|
||||
if ( sqlState != null ) {
|
||||
switch ( sqlState ) {
|
||||
case "41000":
|
||||
return new LockTimeoutException( message, sqlException, sql );
|
||||
case "40001":
|
||||
return new LockAcquisitionException( message, sqlException, sql );
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equivalentTypes(int typeCode1, int typeCode2) {
|
||||
return typeCode1 == Types.LONGVARCHAR && typeCode2 == SqlTypes.JSON
|
||||
|
|
|
@ -682,8 +682,8 @@ public class SqlTypes {
|
|||
|
||||
/**
|
||||
* A type code representing an {@code embedding vector} type for databases
|
||||
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and
|
||||
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
|
||||
* like {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL},
|
||||
* {@link org.hibernate.dialect.OracleDialect Oracle 23ai} and {@link org.hibernate.dialect.MariaDBDialect MariaDB}.
|
||||
* An embedding vector essentially is a {@code float[]} with a fixed size.
|
||||
*
|
||||
* @since 6.4
|
||||
|
|
|
@ -14,6 +14,7 @@ import jakarta.persistence.RollbackException;
|
|||
import org.hibernate.cfg.AvailableSettings;
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.testing.junit4.BaseNonConfigCoreFunctionalTestCase;
|
||||
import org.junit.Test;
|
||||
|
||||
|
@ -69,7 +70,7 @@ public class BatchOptimisticLockingTest extends
|
|||
} );
|
||||
|
||||
try {
|
||||
inTransaction( (session) -> {
|
||||
inTransaction( session -> {
|
||||
List<Person> persons = session
|
||||
.createSelectionQuery( "select p from Person p", Person.class )
|
||||
.getResultList();
|
||||
|
@ -107,10 +108,19 @@ public class BatchOptimisticLockingTest extends
|
|||
}
|
||||
else {
|
||||
assertEquals( OptimisticLockException.class, expected.getClass() );
|
||||
assertTrue(
|
||||
expected.getMessage()
|
||||
.startsWith("Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]")
|
||||
);
|
||||
|
||||
if ( getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
|
||||
assertTrue(
|
||||
expected.getMessage()
|
||||
.contains( "Record has changed since last read in table 'Person'" )
|
||||
);
|
||||
} else {
|
||||
assertTrue(
|
||||
expected.getMessage()
|
||||
.startsWith(
|
||||
"Batch update returned unexpected row count from update 1 (expected row count 1 but was 0) [update Person set name=?,version=? where id=? and version=?]" )
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,8 @@ import jakarta.persistence.RollbackException;
|
|||
import org.hibernate.StaleObjectStateException;
|
||||
import org.hibernate.cfg.AvailableSettings;
|
||||
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.exception.LockAcquisitionException;
|
||||
import org.hibernate.exception.TransactionSerializationException;
|
||||
import org.hibernate.testing.orm.junit.JiraKey;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
|
@ -32,6 +34,7 @@ import jakarta.persistence.Table;
|
|||
import jakarta.persistence.Version;
|
||||
|
||||
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
|
||||
import static org.hibernate.testing.orm.junit.DialectContext.getDialect;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
|
@ -136,7 +139,12 @@ public class BatchUpdateAndVersionTest {
|
|||
fail();
|
||||
}
|
||||
catch (OptimisticLockException ole) {
|
||||
assertTrue( ole.getCause() instanceof StaleObjectStateException );
|
||||
if (getDialect() instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
|
||||
// if @@innodb_snapshot_isolation is set, database throw an exception if record is not available anymore
|
||||
assertTrue( ole.getCause() instanceof LockAcquisitionException );
|
||||
} else {
|
||||
assertTrue( ole.getCause() instanceof StaleObjectStateException );
|
||||
}
|
||||
}
|
||||
//CockroachDB errors with a Serialization Exception
|
||||
catch (RollbackException rbe) {
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.hibernate.StaleObjectStateException;
|
|||
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
|
||||
import org.hibernate.cfg.AvailableSettings;
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.dialect.SQLServerDialect;
|
||||
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
|
||||
import org.hibernate.exception.SQLGrammarException;
|
||||
|
@ -105,6 +106,7 @@ public class RepeatableReadTest extends AbstractJPATest {
|
|||
|
||||
@Test
|
||||
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row")
|
||||
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
|
||||
public void testStaleVersionedInstanceFoundOnLock() {
|
||||
if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) {
|
||||
return;
|
||||
|
@ -228,6 +230,7 @@ public class RepeatableReadTest extends AbstractJPATest {
|
|||
|
||||
@Test
|
||||
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Cockroach uses SERIALIZABLE by default and fails to acquire a write lock after a TX in between committed changes to a row")
|
||||
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
|
||||
public void testStaleNonVersionedInstanceFoundOnLock() {
|
||||
if ( !readCommittedIsolationMaintained( "repeatable read tests" ) ) {
|
||||
return;
|
||||
|
|
|
@ -10,7 +10,7 @@ import java.util.stream.Stream;
|
|||
|
||||
import org.hibernate.LockMode;
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.JiraKey;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
|
@ -30,6 +30,7 @@ import jakarta.persistence.Version;
|
|||
@SessionFactory
|
||||
@JiraKey("HHH-16461")
|
||||
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "CockroachDB uses SERIALIZABLE isolation, and does not support this")
|
||||
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
|
||||
public class OptimisticAndPessimisticLockTest {
|
||||
|
||||
public Stream<LockMode> pessimisticLockModes() {
|
||||
|
|
|
@ -10,8 +10,9 @@ import jakarta.persistence.Id;
|
|||
import jakarta.persistence.Version;
|
||||
import org.hibernate.annotations.OptimisticLock;
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.orm.test.jpa.BaseEntityManagerFunctionalTestCase;
|
||||
import org.hibernate.testing.SkipForDialect;
|
||||
import org.hibernate.testing.orm.junit.SkipForDialect;
|
||||
import org.junit.Test;
|
||||
|
||||
import static org.hibernate.testing.transaction.TransactionUtil.doInJPA;
|
||||
|
@ -29,7 +30,8 @@ public class OptimisticLockTest extends BaseEntityManagerFunctionalTestCase {
|
|||
}
|
||||
|
||||
@Test
|
||||
@SkipForDialect(value = CockroachDialect.class, comment = "Fails at SERIALIZABLE isolation")
|
||||
@SkipForDialect(dialectClass = CockroachDialect.class, reason = "Fails at SERIALIZABLE isolation")
|
||||
@SkipForDialect(dialectClass = MariaDBDialect.class, majorVersion = 11, minorVersion = 6, microVersion = 2, reason = "MariaDB will throw an error DB_RECORD_CHANGED when acquiring a lock on a record that have changed")
|
||||
public void test() {
|
||||
doInJPA(this::entityManagerFactory, entityManager -> {
|
||||
Phone phone = new Phone();
|
||||
|
|
|
@ -13,6 +13,7 @@ import org.hibernate.StaleObjectStateException;
|
|||
import org.hibernate.StaleStateException;
|
||||
import org.hibernate.dialect.CockroachDialect;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.dialect.SQLServerDialect;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DialectFeatureChecks;
|
||||
|
@ -23,6 +24,7 @@ import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
|||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.hibernate.testing.orm.junit.DialectContext.getDialect;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
/**
|
||||
|
@ -189,8 +191,9 @@ public class OptimisticLockTest {
|
|||
"40001" ) ) {
|
||||
// CockroachDB always runs in SERIALIZABLE isolation, and uses SQL state 40001 to indicate
|
||||
// serialization failure.
|
||||
}
|
||||
else {
|
||||
} else if (dialect instanceof MariaDBDialect && getDialect().getVersion().isAfter( 11, 6, 2 )) {
|
||||
// Mariadb snapshot_isolation throws error
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* SPDX-License-Identifier: LGPL-2.1-or-later
|
||||
* Copyright Red Hat Inc. and Hibernate Authors
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.ValueBinder;
|
||||
import org.hibernate.type.descriptor.ValueExtractor;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.BasicBinder;
|
||||
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
import java.sql.CallableStatement;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
|
||||
public class BinaryVectorJdbcType extends ArrayJdbcType {
|
||||
|
||||
public BinaryVectorJdbcType(JdbcType elementJdbcType) {
|
||||
super( elementJdbcType );
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
|
||||
Integer precision,
|
||||
Integer scale,
|
||||
TypeConfiguration typeConfiguration) {
|
||||
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( writeExpression );
|
||||
}
|
||||
|
||||
@Override
|
||||
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
|
||||
return new BasicExtractor<>( javaTypeDescriptor, this ) {
|
||||
@Override
|
||||
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( rs.getObject( paramIndex, float[].class ), options );
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( statement.getObject( index, float[].class ), options );
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( statement.getObject( name, float[].class ), options );
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
|
||||
return new BasicBinder<>( javaTypeDescriptor, this ) {
|
||||
|
||||
@Override
|
||||
protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException {
|
||||
st.setObject( index, value );
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
|
||||
throws SQLException {
|
||||
st.setObject( name, value, java.sql.Types.ARRAY );
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getBindValue(X value, WrapperOptions options) {
|
||||
return value;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* SPDX-License-Identifier: LGPL-2.1-or-later
|
||||
* Copyright Red Hat Inc. and Hibernate Authors
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import org.hibernate.boot.model.FunctionContributions;
|
||||
import org.hibernate.boot.model.FunctionContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class MariaDBFunctionContributor implements FunctionContributor {
|
||||
|
||||
@Override
|
||||
public void contributeFunctions(FunctionContributions functionContributions) {
|
||||
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
|
||||
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
final Dialect dialect = functionContributions.getDialect();
|
||||
if ( dialect instanceof MariaDBDialect ) {
|
||||
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
|
||||
|
||||
functionRegistry.patternDescriptorBuilder( "cosine_distance", "vec_distance_cosine(?1,?2)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vec_distance_euclidean(?1,?2)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordinal() {
|
||||
return 200;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
* SPDX-License-Identifier: LGPL-2.1-or-later
|
||||
* Copyright Red Hat Inc. and Hibernate Authors
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import org.hibernate.boot.model.TypeContributions;
|
||||
import org.hibernate.boot.model.TypeContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.engine.jdbc.Size;
|
||||
import org.hibernate.engine.jdbc.spi.JdbcServices;
|
||||
import org.hibernate.service.ServiceRegistry;
|
||||
import org.hibernate.type.BasicArrayType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
|
||||
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
public class MariaDBTypeContributor implements TypeContributor {
|
||||
|
||||
private static final Type[] VECTOR_JAVA_TYPES = {
|
||||
Float[].class,
|
||||
float[].class
|
||||
};
|
||||
|
||||
@Override
|
||||
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
|
||||
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
|
||||
if ( dialect instanceof MariaDBDialect ) {
|
||||
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
|
||||
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
|
||||
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
|
||||
final ArrayJdbcType vectorJdbcType = new BinaryVectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) );
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType );
|
||||
for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) {
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
floatBasicType,
|
||||
vectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( vectorJavaType )
|
||||
),
|
||||
StandardBasicTypes.VECTOR.getName()
|
||||
);
|
||||
}
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return getTypeName(
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue(),
|
||||
null,
|
||||
null
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,2 +1,3 @@
|
|||
org.hibernate.vector.PGVectorFunctionContributor
|
||||
org.hibernate.vector.OracleVectorFunctionContributor
|
||||
org.hibernate.vector.MariaDBFunctionContributor
|
||||
|
|
|
@ -1,2 +1,3 @@
|
|||
org.hibernate.vector.PGVectorTypeContributor
|
||||
org.hibernate.vector.OracleVectorTypeContributor
|
||||
org.hibernate.vector.MariaDBTypeContributor
|
||||
|
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* SPDX-License-Identifier: LGPL-2.1-or-later
|
||||
* Copyright Red Hat Inc. and Hibernate Authors
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.MariaDBDialect;
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author Diego Dupin
|
||||
*/
|
||||
@DomainModel(annotatedClasses = MariaDBTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = MariaDBDialect.class, matchSubTypes = false, majorVersion = 11, minorVersion = 7)
|
||||
public class MariaDBTest {
|
||||
|
||||
private static final float[] V1 = new float[]{ 1, 2, 3 };
|
||||
private static final float[] V2 = new float[]{ 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
private static double cosineDistance(float[] f1, float[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double innerProduct(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static double euclideanNorm(float[] f) {
|
||||
double result = 0;
|
||||
for ( float v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
@Entity( name = "VectorEntity" )
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column( name = "the_vector" )
|
||||
@JdbcTypeCode(SqlTypes.VECTOR)
|
||||
@Array(length = 3)
|
||||
private float[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, float[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public float[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(float[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue