From 60b0a636299309616c8b1945ea5c722220319736 Mon Sep 17 00:00:00 2001 From: Hassan AL Meftah Date: Thu, 18 Apr 2024 16:08:11 +0100 Subject: [PATCH] HHH-17738 : Add support for Oracle database AI Vector Search --- .../org/hibernate/dialect/OracleDialect.java | 17 + .../dialect/OracleServerConfiguration.java | 44 ++- .../org/hibernate/dialect/OracleTypes.java | 5 + .../java/org/hibernate/type/SqlTypes.java | 23 +- .../hibernate/type/StandardBasicTypes.java | 48 ++- .../vector/AbstractOracleVectorJdbcType.java | 155 +++++++++ .../vector/OracleByteVectorJdbcType.java | 93 ++++++ .../vector/OracleDoubleVectorJdbcType.java | 93 ++++++ .../vector/OracleFloatVectorJdbcType.java | 93 ++++++ .../OracleVectorFunctionContributor.java | 107 +++++++ .../vector/OracleVectorJdbcType.java | 51 +++ .../vector/OracleVectorTypeContributor.java | 167 ++++++++++ .../vector/VectorArgumentValidator.java | 23 +- ...g.hibernate.boot.model.FunctionContributor | 3 +- .../org.hibernate.boot.model.TypeContributor | 3 +- .../vector/OracleByteVectorTest.java | 280 ++++++++++++++++ .../vector/OracleDoubleVectorTest.java | 277 ++++++++++++++++ .../vector/OracleFloatVectorTest.java | 301 +++++++++++++++++ .../vector/OracleGenericVectorTest.java | 302 ++++++++++++++++++ 19 files changed, 2074 insertions(+), 11 deletions(-) create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java create mode 100644 hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java create mode 100644 hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java create mode 100644 hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java create mode 100644 hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java create mode 100644 hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/OracleDialect.java b/hibernate-core/src/main/java/org/hibernate/dialect/OracleDialect.java index 8b22eec2c8..31fad0b874 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/OracleDialect.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/OracleDialect.java @@ -185,6 +185,11 @@ public class OracleDialect extends Dialect { // Is MAX_STRING_SIZE set to EXTENDED? protected final boolean extended; + protected final int driverMajorVersion; + + protected final int driverMinorVersion; + + public OracleDialect() { this( MINIMUM_VERSION ); } @@ -193,6 +198,8 @@ public class OracleDialect extends Dialect { super(version); autonomous = false; extended = false; + driverMajorVersion = 19; + driverMinorVersion = 0; } public OracleDialect(DialectResolutionInfo info) { @@ -203,6 +210,8 @@ public class OracleDialect extends Dialect { super( info ); autonomous = serverConfiguration.isAutonomous(); extended = serverConfiguration.isExtended(); + this.driverMinorVersion = serverConfiguration.getDriverMinorVersion(); + this.driverMajorVersion = serverConfiguration.getDriverMajorVersion(); } @Deprecated( since = "6.4" ) @@ -1621,4 +1630,12 @@ public class OracleDialect extends Dialect { public boolean supportsFromClauseInUpdate() { return true; } + + public int getDriverMajorVersion() { + return driverMajorVersion; + } + + public int getDriverMinorVersion() { + return driverMinorVersion; + } } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/OracleServerConfiguration.java b/hibernate-core/src/main/java/org/hibernate/dialect/OracleServerConfiguration.java index e6005ce8fe..ee7c104ac1 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/OracleServerConfiguration.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/OracleServerConfiguration.java @@ -26,6 +26,8 @@ import static org.hibernate.cfg.DialectSpecificSettings.ORACLE_EXTENDED_STRING_S public class OracleServerConfiguration { private final boolean autonomous; private final boolean extended; + private final int driverMajorVersion; + private final int driverMinorVersion; public boolean isAutonomous() { return autonomous; @@ -35,16 +37,39 @@ public class OracleServerConfiguration { return extended; } + public int getDriverMajorVersion() { + return driverMajorVersion; + } + + public int getDriverMinorVersion() { + return driverMinorVersion; + } + public OracleServerConfiguration(boolean autonomous, boolean extended) { + this( autonomous, extended, 19, 0 ); + } + + public OracleServerConfiguration( + boolean autonomous, + boolean extended, + int driverMajorVersion, + int driverMinorVersion) { this.autonomous = autonomous; this.extended = extended; + this.driverMajorVersion = driverMajorVersion; + this.driverMinorVersion = driverMinorVersion; } public static OracleServerConfiguration fromDialectResolutionInfo(DialectResolutionInfo info) { Boolean extended = null; Boolean autonomous = null; + Integer majorVersion = null; + Integer minorVersion = null; final DatabaseMetaData databaseMetaData = info.getDatabaseMetadata(); if ( databaseMetaData != null ) { + majorVersion = databaseMetaData.getDriverMajorVersion(); + minorVersion = databaseMetaData.getDriverMinorVersion(); + try (final Statement statement = databaseMetaData.getConnection().createStatement()) { final ResultSet rs = statement.executeQuery( "select cast('string' as varchar2(32000)), " + @@ -77,7 +102,19 @@ public class OracleServerConfiguration { false ); } - return new OracleServerConfiguration( autonomous, extended ); + if ( majorVersion == null ) { + try { + java.sql.Driver driver = java.sql.DriverManager.getDriver( "jdbc:oracle:thin:" ); + majorVersion = driver.getMajorVersion(); + minorVersion = driver.getMinorVersion(); + } + catch (SQLException ex) { + majorVersion = 19; + minorVersion = 0; + } + + } + return new OracleServerConfiguration( autonomous, extended, majorVersion, minorVersion ); } private static boolean isAutonomous(String cloudServiceParam) { @@ -86,11 +123,14 @@ public class OracleServerConfiguration { private static boolean isAutonomous(DatabaseMetaData databaseMetaData) { try (final Statement statement = databaseMetaData.getConnection().createStatement()) { - return statement.executeQuery( "select 1 from dual where sys_context('USERENV','CLOUD_SERVICE') in ('OLTP','DWCS','JDCS')" ).next(); + return statement.executeQuery( + "select 1 from dual where sys_context('USERENV','CLOUD_SERVICE') in ('OLTP','DWCS','JDCS')" ) + .next(); } catch (SQLException ex) { // Ignore } return false; } + } diff --git a/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java b/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java index bb0b3d7d32..9409036853 100644 --- a/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/dialect/OracleTypes.java @@ -12,4 +12,9 @@ package org.hibernate.dialect; public class OracleTypes { public static final int CURSOR = -10; public static final int JSON = 2016; + + public static final int VECTOR = -105; + public static final int VECTOR_INT8 = -106; + public static final int VECTOR_FLOAT32 = -107; + public static final int VECTOR_FLOAT64 = -108; } diff --git a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java index 06fe65943b..09aeeeeabe 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/SqlTypes.java @@ -211,7 +211,6 @@ public class SqlTypes { * as a synonym for {@link #VARBINARY}. * * @see org.hibernate.Length#LONG - * * @see Types#LONGVARBINARY * @see org.hibernate.type.descriptor.jdbc.LongVarbinaryJdbcType */ @@ -356,7 +355,6 @@ public class SqlTypes { * as a synonym for {@link #NVARCHAR}. * * @see org.hibernate.Length#LONG - * * @see Types#LONGNVARCHAR * @see org.hibernate.type.descriptor.jdbc.LongNVarcharJdbcType */ @@ -657,13 +655,28 @@ public class SqlTypes { /** * A type code representing an {@code embedding vector} type for databases like - * {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} that have special extensions. + * {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and {@link org.hibernate.dialect.OracleDialect Oracle 23ai}. * An embedding vector essentially is a {@code float[]} with a fixed size. * * @since 6.4 */ public static final int VECTOR = 10_000; + /** + * A type code representing a single-byte integer vector type for oracle 23ai database. + */ + public static final int VECTOR_INT8 = 10_001; + + /** + * A type code representing a single-precision floating-point vector type for oracle 23ai database. + */ + public static final int VECTOR_FLOAT32 = 10_002; + + /** + * A type code representing a double-precision floating-point type for oracle 23ai database. + */ + public static final int VECTOR_FLOAT64 = 10_003; + private SqlTypes() { } @@ -693,6 +706,7 @@ public class SqlTypes { /** * Is this a type with a length, that is, is it * some kind of character string or binary string? + * * @param typeCode a JDBC type code from {@link Types} */ public static boolean isStringType(int typeCode) { @@ -715,6 +729,7 @@ public class SqlTypes { /** * Does the given JDBC type code represent some sort of * character string type? + * * @param typeCode a JDBC type code from {@link Types} */ public static boolean isCharacterOrClobType(int typeCode) { @@ -736,6 +751,7 @@ public class SqlTypes { /** * Does the given JDBC type code represent some sort of * character string type? + * * @param typeCode a JDBC type code from {@link Types} */ public static boolean isCharacterType(int typeCode) { @@ -755,6 +771,7 @@ public class SqlTypes { /** * Does the given JDBC type code represent some sort of * variable-length character string type? + * * @param typeCode a JDBC type code from {@link Types} */ public static boolean isVarcharType(int typeCode) { diff --git a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java index 15330f3fa4..9bccc16c35 100644 --- a/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java +++ b/hibernate-core/src/main/java/org/hibernate/type/StandardBasicTypes.java @@ -551,7 +551,6 @@ public final class StandardBasicTypes { ); - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Binary mappings @@ -746,12 +745,36 @@ public final class StandardBasicTypes { /** * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, - * specifically for embedding vectors like provided by the PostgreSQL extension pgvector. + * specifically for embedding vectors like provided by the PostgreSQL extension pgvector and Oracle 23ai. */ public static final BasicTypeReference VECTOR = new BasicTypeReference<>( "vector", float[].class, SqlTypes.VECTOR ); + /** + * The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8}, + * specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai. + */ + public static final BasicTypeReference VECTOR_INT8 = new BasicTypeReference<>( + "byte_vector", byte[].class, SqlTypes.VECTOR_INT8 + ); + + /** + * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, + * specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai. + */ + public static final BasicTypeReference VECTOR_FLOAT32 = new BasicTypeReference<>( + "float_vector", float[].class, SqlTypes.VECTOR_FLOAT32 + ); + + /** + * The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR}, + * specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai. + */ + public static final BasicTypeReference VECTOR_FLOAT64 = new BasicTypeReference<>( + "double_vector", double[].class, SqlTypes.VECTOR_FLOAT64 + ); + public static void prime(TypeConfiguration typeConfiguration) { BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); @@ -1252,6 +1275,27 @@ public final class StandardBasicTypes { "vector" ); + handle( + VECTOR_FLOAT32, + null, + basicTypeRegistry, + "float_vector" + ); + + handle( + VECTOR_FLOAT64, + null, + basicTypeRegistry, + "double_vector" + ); + + handle( + VECTOR_INT8, + null, + basicTypeRegistry, + "byte_vector" + ); + // Specialized version handlers diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java new file mode 100644 index 0000000000..f73cb049dd --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java @@ -0,0 +1,155 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.sql.CallableStatement; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.hibernate.dialect.Dialect; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.ValueBinder; +import org.hibernate.type.descriptor.ValueExtractor; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.BasicPluralJavaType; +import org.hibernate.type.descriptor.java.ByteJavaType; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType; +import org.hibernate.type.descriptor.jdbc.ArrayJdbcType; +import org.hibernate.type.descriptor.jdbc.BasicBinder; +import org.hibernate.type.descriptor.jdbc.BasicExtractor; +import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.internal.JdbcLiteralFormatterArray; + +/** + * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle. + *

+ * This class handles generic vectors represented by an asterisk (*) in the format, + * allowing for different element types within the vector. + * + * @author Hassan AL Meftah + */ +public abstract class AbstractOracleVectorJdbcType extends ArrayJdbcType { + + final boolean isVectorSupported; + + public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType ); + this.isVectorSupported = isVectorSupported; + } + + public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect); + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR; + } + + @Override + public JdbcLiteralFormatter getJdbcLiteralFormatter(JavaType javaTypeDescriptor) { + final JavaType elementJavaType; + if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) { + // Special handling needed for Byte[], because that would conflict with the VARBINARY mapping + //noinspection unchecked + elementJavaType = (JavaType) ByteJavaType.INSTANCE; + } + else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) { + //noinspection unchecked + elementJavaType = ( (BasicPluralJavaType) javaTypeDescriptor ).getElementJavaType(); + } + else { + throw new IllegalArgumentException( "not a BasicPluralJavaType" ); + } + return new JdbcLiteralFormatterArray<>( + javaTypeDescriptor, + getElementJdbcType().getJdbcLiteralFormatter( elementJavaType ) + ); + } + + @Override + public String toString() { + return "OracleVectorTypeDescriptor"; + } + + + @Override + public ValueBinder getBinder(final JavaType javaTypeDescriptor) { + return new BasicBinder<>( javaTypeDescriptor, this ) { + @Override + protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( index, value, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getNativeTypeCode() ); + } + else { + st.setString( index, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getStringVector( value, getJavaType(), options ) ); + } + } + + @Override + protected void doBind(CallableStatement st, X value, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + st.setObject( name, value, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getNativeTypeCode() ); + } + else { + st.setString( name, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getStringVector( value, getJavaType(), options ) ); + } + } + + }; + } + + @Override + public ValueExtractor getExtractor(final JavaType javaTypeDescriptor) { + return new BasicExtractor<>( javaTypeDescriptor, this ) { + @Override + protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( rs.getObject( paramIndex, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options ); + } + else { + return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( rs.getString( paramIndex ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( statement.getObject( index, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options ); + } + else { + return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( statement.getString( index ) ), options ); + } + } + + @Override + protected X doExtract(CallableStatement statement, String name, WrapperOptions options) + throws SQLException { + if ( isVectorSupported ) { + return getJavaType().wrap( statement.getObject( name, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options ); + } + else { + return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( statement.getString( name ) ), options ); + } + } + + }; + } + + protected abstract T getVectorArray(String string); + + protected abstract String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options); + + protected abstract Class getNativeJavaType(); + + protected abstract int getNativeTypeCode(); + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java new file mode 100644 index 0000000000..ed11b4dd2d --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleByteVectorJdbcType.java @@ -0,0 +1,93 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.BitSet; + +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleTypes; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +/** + * Specialized type mapping for single-byte integer vector {@link SqlTypes#VECTOR_INT8} SQL data type for Oracle. + * + * @author Hassan AL Meftah + */ +public class OracleByteVectorJdbcType extends AbstractOracleVectorJdbcType { + + + private static final byte[] EMPTY = new byte[0]; + + public OracleByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ", *, INT8)" ); + } + + @Override + public String getFriendlyName() { + return "VECTOR_INT8"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_INT8; + } + + @Override + protected byte[] getVectorArray(String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final byte[] result = new byte[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) ); + } + + protected Class getNativeJavaType(){ + return byte[].class; + }; + + protected int getNativeTypeCode(){ + return OracleTypes.VECTOR_INT8; + }; + +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java new file mode 100644 index 0000000000..d32aa596dd --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleDoubleVectorJdbcType.java @@ -0,0 +1,93 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.BitSet; + +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleTypes; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +/** + * Specialized type mapping for double-precision floating-point vector {@link SqlTypes#VECTOR_FLOAT64} SQL data type for Oracle. + * + * @author Hassan AL Meftah + */ +public class OracleDoubleVectorJdbcType extends AbstractOracleVectorJdbcType { + + private static final double[] EMPTY = new double[0]; + + public OracleDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ", *, FLOAT64)" ); + } + + @Override + public String getFriendlyName() { + return "VECTOR_FLOAT64"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_FLOAT64; + } + + @Override + protected double[] getVectorArray(String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final double[] result = new double[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, double[].class, options ) ); + } + + protected Class getNativeJavaType() { + return double[].class; + } + + @Override + protected int getNativeTypeCode() { + return OracleTypes.VECTOR_FLOAT64; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java new file mode 100644 index 0000000000..17e81b5967 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleFloatVectorJdbcType.java @@ -0,0 +1,93 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.Arrays; +import java.util.BitSet; + +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleTypes; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.WrapperOptions; +import org.hibernate.type.descriptor.java.JavaType; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +/** + * Specialized type mapping for single-precision floating-point vector {@link SqlTypes#VECTOR_FLOAT32} SQL data type for Oracle. + * + * @author Hassan AL Meftah + */ + +public class OracleFloatVectorJdbcType extends AbstractOracleVectorJdbcType { + + + private static final float[] EMPTY = new float[0]; + + public OracleFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ", *, FLOAT32)" ); + } + + @Override + public String getFriendlyName() { + return "VECTOR_FLOAT32"; + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR_FLOAT32; + } + + @Override + protected float[] getVectorArray(String string) { + if ( string == null ) { + return null; + } + if ( string.length() == 2 ) { + return EMPTY; + } + final BitSet commaPositions = new BitSet(); + int size = 1; + for ( int i = 1; i < string.length(); i++ ) { + final char c = string.charAt( i ); + if ( c == ',' ) { + commaPositions.set( i ); + size++; + } + } + final float[] result = new float[size]; + int doubleStartIndex = 1; + int commaIndex; + int index = 0; + while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) { + result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) ); + doubleStartIndex = commaIndex + 1; + } + result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) ); + return result; + } + + @Override + protected String getStringVector(T vector, JavaType javaTypeDescriptor, WrapperOptions options) { + return Arrays.toString( javaTypeDescriptor.unwrap( vector, float[].class, options ) ); + } + + protected Class getNativeJavaType() { + return float[].class; + } + + protected int getNativeTypeCode() { + return OracleTypes.VECTOR_FLOAT32; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java new file mode 100644 index 0000000000..0b8aa0d315 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorFunctionContributor.java @@ -0,0 +1,107 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import org.hibernate.boot.model.FunctionContributions; +import org.hibernate.boot.model.FunctionContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.query.sqm.function.SqmFunctionRegistry; +import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators; +import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.spi.TypeConfiguration; + +public class OracleVectorFunctionContributor implements FunctionContributor { + + @Override + public void contributeFunctions(FunctionContributions functionContributions) { + final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry(); + final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + final Dialect dialect = functionContributions.getDialect(); + if ( dialect instanceof OracleDialect ) { + final BasicType doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ); + final BasicType integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER ); + functionRegistry.patternDescriptorBuilder( "cosine_distance", "vector_distance(?1, ?2, COSINE)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" ); + + functionRegistry.patternDescriptorBuilder( "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)") + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" ); + + functionRegistry.patternDescriptorBuilder( "negative_inner_product", "vector_distance(?1, ?2, DOT)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.patternDescriptorBuilder( "inner_product", "vector_distance(?1, ?2, DOT)*-1" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.patternDescriptorBuilder( "hamming_distance", "vector_distance(?1, ?2, HAMMING)" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 2 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + functionRegistry.namedDescriptorBuilder( "vector_dims" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( integerType ) ) + .register(); + functionRegistry.namedDescriptorBuilder( "vector_norm" ) + .setArgumentsValidator( StandardArgumentsValidators.composite( + StandardArgumentsValidators.exactly( 1 ), + VectorArgumentValidator.INSTANCE + ) ) + .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE ) + .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) ) + .register(); + } + } + + @Override + public int ordinal() { + return 200; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java new file mode 100644 index 0000000000..649626ee99 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorJdbcType.java @@ -0,0 +1,51 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleTypes; +import org.hibernate.sql.ast.spi.SqlAppender; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.descriptor.jdbc.JdbcType; + +/** + * Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle. + *

+ * This class handles generic vectors represented by an asterisk (*) in the format, + * allowing for different element types within the vector. + * + * @author Hassan AL Meftah + */ +public class OracleVectorJdbcType extends OracleFloatVectorJdbcType { + + + public OracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) { + super( elementJdbcType, isVectorSupported ); + } + + @Override + public String getFriendlyName() { + return "VECTOR"; + } + + @Override + public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) { + appender.append( "to_vector(" ); + appender.append( writeExpression ); + appender.append( ", *, *)" ); + } + + @Override + public int getDefaultSqlTypeCode() { + return SqlTypes.VECTOR; + } + + @Override + protected int getNativeTypeCode() { + return OracleTypes.VECTOR; + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java new file mode 100644 index 0000000000..994f70a142 --- /dev/null +++ b/hibernate-vector/src/main/java/org/hibernate/vector/OracleVectorTypeContributor.java @@ -0,0 +1,167 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.sql.SQLException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.hibernate.boot.model.TypeContributions; +import org.hibernate.boot.model.TypeContributor; +import org.hibernate.dialect.Dialect; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.engine.jdbc.Size; +import org.hibernate.engine.jdbc.spi.JdbcServices; +import org.hibernate.internal.util.StringHelper; +import org.hibernate.service.ServiceRegistry; +import org.hibernate.type.BasicArrayType; +import org.hibernate.type.BasicType; +import org.hibernate.type.BasicTypeReference; +import org.hibernate.type.BasicTypeRegistry; +import org.hibernate.type.SqlTypes; +import org.hibernate.type.StandardBasicTypes; +import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry; +import org.hibernate.type.descriptor.jdbc.JdbcType; +import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry; +import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl; +import org.hibernate.type.spi.TypeConfiguration; + +public class OracleVectorTypeContributor implements TypeContributor { + + + @Override + public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) { + final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ) + .getDialect(); + + if ( dialect instanceof OracleDialect && dialect.getVersion().isSameOrAfter( 23 ) ) { + final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration(); + final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry(); + final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry(); + final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry(); + + final boolean isVectorSupported = isVectorSupportedByDriver( (OracleDialect) dialect ); + + // Register generic vector type + final OracleVectorJdbcType genericVectorJdbcType = new OracleVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType ); + final JdbcType floatVectorJdbcType = new OracleFloatVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType ); + final JdbcType doubleVectorJdbcType = new OracleDoubleVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.DOUBLE ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT64, doubleVectorJdbcType ); + final JdbcType byteVectorJdbcType = new OracleByteVectorJdbcType( + jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ), + isVectorSupported + ); + jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType ); + + + // Resolving basic types after jdbc types are registered. + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + genericVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ), + floatVectorJdbcType, + javaTypeRegistry.getDescriptor( float[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT32.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ), + doubleVectorJdbcType, + javaTypeRegistry.getDescriptor( double[].class ) + ), + StandardBasicTypes.VECTOR_FLOAT64.getName() + ); + basicTypeRegistry.register( + new BasicArrayType<>( + basicTypeRegistry.resolve( StandardBasicTypes.BYTE ), + byteVectorJdbcType, + javaTypeRegistry.getDescriptor( byte[].class ) + ), + StandardBasicTypes.VECTOR_INT8.getName() + ); + + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new DdlTypeImpl( SqlTypes.VECTOR, "vector($l, *)", "vector", dialect ) { + @Override + public String getTypeName(Size size) { + return OracleVectorTypeContributor.replace( + "vector($l, *)", + size.getArrayLength() == null ? null : size.getArrayLength().longValue() + ); + } + } + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new DdlTypeImpl( SqlTypes.VECTOR_INT8, "vector($l, INT8)", "vector", dialect ) { + @Override + public String getTypeName(Size size) { + return OracleVectorTypeContributor.replace( + "vector($l, INT8)", + size.getArrayLength() == null ? null : size.getArrayLength().longValue() + ); + } + } + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new DdlTypeImpl( SqlTypes.VECTOR_FLOAT32, "vector($l, FLOAT32)", "vector", dialect ) { + @Override + public String getTypeName(Size size) { + return OracleVectorTypeContributor.replace( + "vector($l, FLOAT32)", + size.getArrayLength() == null ? null : size.getArrayLength().longValue() + ); + } + } + ); + typeConfiguration.getDdlTypeRegistry().addDescriptor( + new DdlTypeImpl( SqlTypes.VECTOR_FLOAT64, "vector($l, FLOAT64)", "vector", dialect ) { + @Override + public String getTypeName(Size size) { + return OracleVectorTypeContributor.replace( + "vector($l, FLOAT64)", + size.getArrayLength() == null ? null : size.getArrayLength().longValue() + ); + } + } + ); + } + } + + + /** + * Replace vector dimension with the length or * for undefined length + */ + private static String replace(String type, Long size) { + return StringHelper.replaceOnce( type, "$l", size != null ? size.toString() : "*" ); + } + + private boolean isVectorSupportedByDriver(OracleDialect dialect) { + int majorVersion = dialect.getDriverMajorVersion(); + int minorVersion = dialect.getDriverMinorVersion(); + + return ( majorVersion > 23 ) || ( majorVersion == 23 && minorVersion >= 4 ); + } +} diff --git a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java b/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java index 8e4766ef36..826d6cd3f5 100644 --- a/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java +++ b/hibernate-vector/src/main/java/org/hibernate/vector/VectorArgumentValidator.java @@ -14,6 +14,7 @@ import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.produce.function.FunctionArgumentException; import org.hibernate.query.sqm.tree.SqmTypedNode; import org.hibernate.type.BasicPluralType; +import org.hibernate.type.BasicType; import org.hibernate.type.SqlTypes; import org.hibernate.type.spi.TypeConfiguration; @@ -24,6 +25,13 @@ public class VectorArgumentValidator implements ArgumentsValidator { public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator(); + private static final int[] availableVectorCodes = { + SqlTypes.VECTOR, + SqlTypes.VECTOR_INT8, + SqlTypes.VECTOR_FLOAT32, + SqlTypes.VECTOR_FLOAT64 + }; + @Override public void validate( List> arguments, @@ -46,7 +54,18 @@ public class VectorArgumentValidator implements ArgumentsValidator { } private static boolean isVectorType(SqmExpressible vectorType) { - return vectorType instanceof BasicPluralType - && ( (BasicPluralType) vectorType ).getJdbcType().getDefaultSqlTypeCode() == SqlTypes.VECTOR; + if ( !( vectorType instanceof BasicPluralType ) ) { + return false; + } + + switch ( ( (BasicType) vectorType ).getJdbcType().getDefaultSqlTypeCode() ) { + case SqlTypes.VECTOR: + case SqlTypes.VECTOR_INT8: + case SqlTypes.VECTOR_FLOAT32: + case SqlTypes.VECTOR_FLOAT64: + return true; + default: + return false; + } } } diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor index 50860d311b..477fddb117 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor @@ -1 +1,2 @@ -org.hibernate.vector.PGVectorFunctionContributor \ No newline at end of file +org.hibernate.vector.PGVectorFunctionContributor +org.hibernate.vector.OracleVectorFunctionContributor \ No newline at end of file diff --git a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor index 51a46ab722..eeaa217a75 100644 --- a/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor +++ b/hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor @@ -1 +1,2 @@ -org.hibernate.vector.PGVectorTypeContributor \ No newline at end of file +org.hibernate.vector.PGVectorTypeContributor +org.hibernate.vector.OracleVectorTypeContributor \ No newline at end of file diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java new file mode 100644 index 0000000000..e26cc56f66 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/OracleByteVectorTest.java @@ -0,0 +1,280 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.List; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.type.SqlTypes; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author Hassan AL Meftah + */ +@DomainModel(annotatedClasses = OracleByteVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) +public class OracleByteVectorTest { + + private static final byte[] V1 = new byte[]{ 1, 2, 3 }; + private static final byte[] V2 = new byte[]{ 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new byte[]{ 1, 2, 3 }, tableRecord.getTheVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new byte[]{ 4, 5, 6 }, tableRecord.getTheVector() ); + } ); + } + + @Test + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::cosine-distance-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector, byte[].class ) + .getResultList(); + //end::cosine-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-distance-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::taxicab-distance-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::taxicab-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final byte[] vector = new byte[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-dims-example[] + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + //end::vector-dims-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-norm-example[] + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + //end::vector-norm-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + private static double cosineDistance(byte[] f1, byte[] f2) { + return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); + } + + private static double euclideanDistance(byte[] f1, byte[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += Math.pow( (double) f1[i] - f2[i], 2 ); + } + return Math.sqrt( result ); + } + + private static double taxicabDistance(byte[] f1, byte[] f2) { + return norm( f1 ) - norm( f2 ); + } + + private static double innerProduct(byte[] f1, byte[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += ( (double) f1[i] ) * ( (double) f2[i] ); + } + return result; + } + + public static double hammingDistance(byte[] f1, byte[] f2) { + assert f1.length == f2.length; + int distance = 0; + for (int i = 0; i < f1.length; i++) { + if (!(f1[i] == f2[i])) { + distance++; + } + } + return distance; + } + + + private static double euclideanNorm(byte[] f) { + double result = 0; + for ( double v : f ) { + result += Math.pow( v, 2 ); + } + return Math.sqrt( result ); + } + + private static double norm(byte[] f) { + double result = 0; + for ( double v : f ) { + result += Math.abs( v ); + } + return result; + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + //tag::usage-example[] + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.VECTOR_INT8) + @Array(length = 3) + private byte[] theVector; + //end::usage-example[] + + + + public VectorEntity() { + } + + public VectorEntity(Long id, byte[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public byte[] getTheVector() { + return theVector; + } + + public void setTheVector(byte[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java new file mode 100644 index 0000000000..ef482deac1 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/OracleDoubleVectorTest.java @@ -0,0 +1,277 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.List; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.type.SqlTypes; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author Hassan AL Meftah + */ +@DomainModel(annotatedClasses = OracleDoubleVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) +public class OracleDoubleVectorTest { + + private static final double[] V1 = new double[]{ 1, 2, 3 }; + private static final double[] V2 = new double[]{ 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new double[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new double[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Test + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::cosine-distance-example[] + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::cosine-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-distance-example[] + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.00002D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.00002D); + } ); + } + + @Test + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::taxicab-distance-example[] + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::taxicab-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final double[] vector = new double[]{ 1, 1, 1 }; + final List results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-dims-example[] + final List results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + //end::vector-dims-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-norm-example[] + final List results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class ) + .getResultList(); + //end::vector-norm-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + + private static double cosineDistance(double[] f1, double[] f2) { + return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); + } + + private static double euclideanDistance(double[] f1, double[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += Math.pow( (double) f1[i] - f2[i], 2 ); + } + return Math.sqrt( result ); + } + + private static double taxicabDistance(double[] f1, double[] f2) { + return norm( f1 ) - norm( f2 ); + } + + public static double hammingDistance(double[] f1, double[] f2) { + assert f1.length == f2.length; + int distance = 0; + for (int i = 0; i < f1.length; i++) { + if (!(f1[i] == f2[i])) { + distance++; + } + } + return distance; + } + + private static double innerProduct(double[] f1, double[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += ( (double) f1[i] ) * ( (double) f2[i] ); + } + return result; + } + + private static double euclideanNorm(double[] f) { + double result = 0; + for ( double v : f ) { + result += Math.pow( v, 2 ); + } + return Math.sqrt( result ); + } + + private static double norm(double[] f) { + double result = 0; + for ( double v : f ) { + result += Math.abs( v ); + } + return result; + } + + @Entity( name = "VectorEntity" ) + public static class VectorEntity { + + @Id + private Long id; + + //tag::usage-example[] + @Column( name = "the_vector" ) + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT64) + @Array(length = 3) + private double[] theVector; + //end::usage-example[] + + public VectorEntity() { + } + + public VectorEntity(Long id, double[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public double[] getTheVector() { + return theVector; + } + + public void setTheVector(double[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java new file mode 100644 index 0000000000..c813479bd8 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/OracleFloatVectorTest.java @@ -0,0 +1,301 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.List; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.type.SqlTypes; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author Hassan AL Meftah + */ +@DomainModel(annotatedClasses = OracleFloatVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) +public class OracleFloatVectorTest { + + private static final float[] V1 = new float[] { 1, 2, 3 }; + private static final float[] V2 = new float[] { 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 ); + } ); + } + + @Test + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::cosine-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::cosine-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::taxicab-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::taxicab-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-dims-example[] + final List results = em.createSelectionQuery( + "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::vector-dims-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-norm-example[] + final List results = em.createSelectionQuery( + "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::vector-norm-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + + private static double cosineDistance(float[] f1, float[] f2) { + return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); + } + + private static double euclideanDistance(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += Math.pow( (double) f1[i] - f2[i], 2 ); + } + return Math.sqrt( result ); + } + + private static double taxicabDistance(float[] f1, float[] f2) { + return norm( f1 ) - norm( f2 ); + } + + private static double innerProduct(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += ( (double) f1[i] ) * ( (double) f2[i] ); + } + return result; + } + + public static double hammingDistance(float[] f1, float[] f2) { + assert f1.length == f2.length; + int distance = 0; + for ( int i = 0; i < f1.length; i++ ) { + if ( !( f1[i] == f2[i] ) ) { + distance++; + } + } + return distance; + } + + + private static double euclideanNorm(float[] f) { + double result = 0; + for ( double v : f ) { + result += Math.pow( v, 2 ); + } + return Math.sqrt( result ); + } + + private static double norm(float[] f) { + double result = 0; + for ( double v : f ) { + result += Math.abs( v ); + } + return result; + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + //tag::usage-example[] + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR_FLOAT32) + @Array(length = 3) + private float[] theVector; + //end::usage-example[] + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +} diff --git a/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java b/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java new file mode 100644 index 0000000000..3641398c48 --- /dev/null +++ b/hibernate-vector/src/test/java/org/hibernate/vector/OracleGenericVectorTest.java @@ -0,0 +1,302 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * License: GNU Lesser General Public License (LGPL), version 2.1 or later. + * See the lgpl.txt file in the root directory or . + */ +package org.hibernate.vector; + +import java.util.List; + +import org.hibernate.annotations.Array; +import org.hibernate.annotations.JdbcTypeCode; +import org.hibernate.dialect.OracleDialect; +import org.hibernate.type.SqlTypes; + +import org.hibernate.testing.orm.junit.DomainModel; +import org.hibernate.testing.orm.junit.RequiresDialect; +import org.hibernate.testing.orm.junit.SessionFactory; +import org.hibernate.testing.orm.junit.SessionFactoryScope; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.Id; +import jakarta.persistence.Tuple; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +/** + * @author Hassan AL Meftah + */ +@DomainModel(annotatedClasses = OracleGenericVectorTest.VectorEntity.class) +@SessionFactory +@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4) +public class OracleGenericVectorTest { + + private static final float[] V1 = new float[] { 1, 2, 3 }; + private static final float[] V2 = new float[] { 4, 5, 6 }; + + @BeforeEach + public void prepareData(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.persist( new VectorEntity( 1L, V1 ) ); + em.persist( new VectorEntity( 2L, V2 ) ); + } ); + } + + @AfterEach + public void cleanup(SessionFactoryScope scope) { + scope.inTransaction( em -> { + em.createMutationQuery( "delete from VectorEntity" ).executeUpdate(); + } ); + } + + @Test + public void testRead(SessionFactoryScope scope) { + scope.inTransaction( em -> { + VectorEntity tableRecord; + tableRecord = em.find( VectorEntity.class, 1L ); + assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector() ); + + tableRecord = em.find( VectorEntity.class, 2L ); + assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector() ); + } ); + } + + @Test + public void testCosineDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::cosine-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::cosine-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D ); + } ); + } + + @Test + public void testEuclideanDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::euclidean-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::euclidean-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D ); + } ); + } + + @Test + public void testTaxicabDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::taxicab-distance-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::taxicab-distance-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testInnerProduct(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D ); + } ); + } + + @Test + public void testHammingDistance(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::inner-product-example[] + final float[] vector = new float[] { 1, 1, 1 }; + final List results = em.createSelectionQuery( + "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", + Tuple.class + ) + .setParameter( "vec", vector ) + .getResultList(); + //end::inner-product-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + @Test + public void testVectorDims(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-dims-example[] + final List results = em.createSelectionQuery( + "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::vector-dims-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( V1.length, results.get( 0 ).get( 1 ) ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( V2.length, results.get( 1 ).get( 1 ) ); + } ); + } + + @Test + public void testVectorNorm(SessionFactoryScope scope) { + scope.inTransaction( em -> { + //tag::vector-norm-example[] + final List results = em.createSelectionQuery( + "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", + Tuple.class + ) + .getResultList(); + //end::vector-norm-example[] + assertEquals( 2, results.size() ); + assertEquals( 1L, results.get( 0 ).get( 0 ) ); + assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D ); + assertEquals( 2L, results.get( 1 ).get( 0 ) ); + assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D ); + } ); + } + + + private static double cosineDistance(float[] f1, float[] f2) { + return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) ); + } + + private static double euclideanDistance(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += Math.pow( (double) f1[i] - f2[i], 2 ); + } + return Math.sqrt( result ); + } + + private static double taxicabDistance(float[] f1, float[] f2) { + return norm( f1 ) - norm( f2 ); + } + + private static double innerProduct(float[] f1, float[] f2) { + assert f1.length == f2.length; + double result = 0; + for ( int i = 0; i < f1.length; i++ ) { + result += ( (double) f1[i] ) * ( (double) f2[i] ); + } + return result; + } + + public static double hammingDistance(float[] f1, float[] f2) { + assert f1.length == f2.length; + int distance = 0; + for ( int i = 0; i < f1.length; i++ ) { + if ( !( f1[i] == f2[i] ) ) { + distance++; + } + } + return distance; + } + + + private static double euclideanNorm(float[] f) { + double result = 0; + for ( double v : f ) { + result += Math.pow( v, 2 ); + } + return Math.sqrt( result ); + } + + private static double norm(float[] f) { + double result = 0; + for ( double v : f ) { + result += Math.abs( v ); + } + return result; + } + + @Entity(name = "VectorEntity") + public static class VectorEntity { + + @Id + private Long id; + + //tag::usage-example[] + @Column(name = "the_vector") + @JdbcTypeCode(SqlTypes.VECTOR) + @Array(length = 3) + private float[] theVector; + //end::usage-example[] + + + public VectorEntity() { + } + + public VectorEntity(Long id, float[] theVector) { + this.id = id; + this.theVector = theVector; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public float[] getTheVector() { + return theVector; + } + + public void setTheVector(float[] theVector) { + this.theVector = theVector; + } + } +}