HHH-17738 : Add support for Oracle database AI Vector Search

This commit is contained in:
Hassan AL Meftah 2024-04-18 16:08:11 +01:00 committed by Christian Beikov
parent 4791b41cf5
commit 60b0a63629
19 changed files with 2074 additions and 11 deletions

View File

@ -185,6 +185,11 @@ public class OracleDialect extends Dialect {
// Is MAX_STRING_SIZE set to EXTENDED?
protected final boolean extended;
protected final int driverMajorVersion;
protected final int driverMinorVersion;
/**
 * Creates a dialect by delegating to another constructor of this class,
 * passing {@code MINIMUM_VERSION} as the version.
 */
public OracleDialect() {
this( MINIMUM_VERSION );
}
@ -193,6 +198,8 @@ public class OracleDialect extends Dialect {
super(version);
autonomous = false;
extended = false;
driverMajorVersion = 19;
driverMinorVersion = 0;
}
public OracleDialect(DialectResolutionInfo info) {
@ -203,6 +210,8 @@ public class OracleDialect extends Dialect {
super( info );
autonomous = serverConfiguration.isAutonomous();
extended = serverConfiguration.isExtended();
this.driverMinorVersion = serverConfiguration.getDriverMinorVersion();
this.driverMajorVersion = serverConfiguration.getDriverMajorVersion();
}
@Deprecated( since = "6.4" )
@ -1621,4 +1630,12 @@ public class OracleDialect extends Dialect {
public boolean supportsFromClauseInUpdate() {
return true;
}
/**
 * Exposes the JDBC driver major version held by this dialect.
 *
 * @return the value of the {@code driverMajorVersion} field assigned at construction
 */
public int getDriverMajorVersion() {
    return this.driverMajorVersion;
}
/**
 * Exposes the JDBC driver minor version held by this dialect.
 *
 * @return the value of the {@code driverMinorVersion} field assigned at construction
 */
public int getDriverMinorVersion() {
    return this.driverMinorVersion;
}
}

View File

@ -26,6 +26,8 @@ import static org.hibernate.cfg.DialectSpecificSettings.ORACLE_EXTENDED_STRING_S
public class OracleServerConfiguration {
private final boolean autonomous;
private final boolean extended;
private final int driverMajorVersion;
private final int driverMinorVersion;
public boolean isAutonomous() {
return autonomous;
@ -35,16 +37,39 @@ public class OracleServerConfiguration {
return extended;
}
/**
 * Exposes the JDBC driver major version stored in this configuration.
 *
 * @return the value of the {@code driverMajorVersion} field
 */
public int getDriverMajorVersion() {
    return this.driverMajorVersion;
}
/**
 * Exposes the JDBC driver minor version stored in this configuration.
 *
 * @return the value of the {@code driverMinorVersion} field
 */
public int getDriverMinorVersion() {
    return this.driverMinorVersion;
}
/**
 * Convenience constructor that delegates to the four-argument constructor,
 * passing {@code 19} as the driver major version and {@code 0} as the
 * driver minor version.
 *
 * @param autonomous forwarded unchanged to the four-argument constructor
 * @param extended forwarded unchanged to the four-argument constructor
 */
public OracleServerConfiguration(boolean autonomous, boolean extended) {
this( autonomous, extended, 19, 0 );
}
/**
 * Creates a configuration from explicitly supplied values; every argument
 * is stored as-is in the field of the same name.
 *
 * @param autonomous value for the {@code autonomous} flag
 * @param extended value for the {@code extended} flag
 * @param driverMajorVersion JDBC driver major version to record
 * @param driverMinorVersion JDBC driver minor version to record
 */
public OracleServerConfiguration(
        boolean autonomous,
        boolean extended,
        int driverMajorVersion,
        int driverMinorVersion) {
    // Driver version first, then the two server flags; the assignments are independent
    this.driverMajorVersion = driverMajorVersion;
    this.driverMinorVersion = driverMinorVersion;
    this.extended = extended;
    this.autonomous = autonomous;
}
public static OracleServerConfiguration fromDialectResolutionInfo(DialectResolutionInfo info) {
Boolean extended = null;
Boolean autonomous = null;
Integer majorVersion = null;
Integer minorVersion = null;
final DatabaseMetaData databaseMetaData = info.getDatabaseMetadata();
if ( databaseMetaData != null ) {
majorVersion = databaseMetaData.getDriverMajorVersion();
minorVersion = databaseMetaData.getDriverMinorVersion();
try (final Statement statement = databaseMetaData.getConnection().createStatement()) {
final ResultSet rs = statement.executeQuery(
"select cast('string' as varchar2(32000)), " +
@ -77,7 +102,19 @@ public class OracleServerConfiguration {
false
);
}
return new OracleServerConfiguration( autonomous, extended );
if ( majorVersion == null ) {
try {
java.sql.Driver driver = java.sql.DriverManager.getDriver( "jdbc:oracle:thin:" );
majorVersion = driver.getMajorVersion();
minorVersion = driver.getMinorVersion();
}
catch (SQLException ex) {
majorVersion = 19;
minorVersion = 0;
}
}
return new OracleServerConfiguration( autonomous, extended, majorVersion, minorVersion );
}
private static boolean isAutonomous(String cloudServiceParam) {
@ -86,11 +123,14 @@ public class OracleServerConfiguration {
/**
 * Probes the connected database for an autonomous cloud service.
 * <p>
 * Runs a query against {@code sys_context('USERENV','CLOUD_SERVICE')} and
 * reports whether it produced a row. The probe is best-effort: when the
 * query fails with a {@link SQLException} the result is {@code false}.
 *
 * @param databaseMetaData metadata whose connection is used for the probe query
 * @return {@code true} if the probe query returned a row, {@code false} otherwise
 */
private static boolean isAutonomous(DatabaseMetaData databaseMetaData) {
    // Statement and result set are both released when the try block exits
    try ( final Statement statement = databaseMetaData.getConnection().createStatement();
            final ResultSet resultSet = statement.executeQuery(
                    "select 1 from dual where sys_context('USERENV','CLOUD_SERVICE') in ('OLTP','DWCS','JDCS')" ) ) {
        return resultSet.next();
    }
    catch (SQLException ignored) {
        // Deliberately ignored: a failed probe is treated as "not autonomous"
    }
    return false;
}
}

View File

@ -12,4 +12,9 @@ package org.hibernate.dialect;
// Holder for Oracle-specific integer type codes.
// NOTE(review): presumably these mirror the constants of the Oracle JDBC driver's
// oracle.jdbc.OracleTypes class - confirm against the driver documentation.
public class OracleTypes {
public static final int CURSOR = -10;
public static final int JSON = 2016;
// Vector type codes. The Oracle vector JDBC type mappings return these from
// getNativeTypeCode() and pass them as the target SQL type to setObject(...).
public static final int VECTOR = -105;
public static final int VECTOR_INT8 = -106;
public static final int VECTOR_FLOAT32 = -107;
public static final int VECTOR_FLOAT64 = -108;
}

View File

@ -211,7 +211,6 @@ public class SqlTypes {
* as a synonym for {@link #VARBINARY}.
*
* @see org.hibernate.Length#LONG
*
* @see Types#LONGVARBINARY
* @see org.hibernate.type.descriptor.jdbc.LongVarbinaryJdbcType
*/
@ -356,7 +355,6 @@ public class SqlTypes {
* as a synonym for {@link #NVARCHAR}.
*
* @see org.hibernate.Length#LONG
*
* @see Types#LONGNVARCHAR
* @see org.hibernate.type.descriptor.jdbc.LongNVarcharJdbcType
*/
@ -657,13 +655,28 @@ public class SqlTypes {
/**
* A type code representing an {@code embedding vector} type for databases like
* {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} that have special extensions.
* {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
* An embedding vector essentially is a {@code float[]} with a fixed size.
*
* @since 6.4
*/
public static final int VECTOR = 10_000;
/**
* A type code representing a single-byte integer vector type for the Oracle 23ai database.
*/
public static final int VECTOR_INT8 = 10_001;
/**
* A type code representing a single-precision floating-point vector type for the Oracle 23ai database.
*/
public static final int VECTOR_FLOAT32 = 10_002;
/**
* A type code representing a double-precision floating-point vector type for the Oracle 23ai database.
*/
public static final int VECTOR_FLOAT64 = 10_003;
// Private no-op constructor: the class cannot be instantiated from outside.
private SqlTypes() {
}
@ -693,6 +706,7 @@ public class SqlTypes {
/**
* Is this a type with a length, that is, is it
* some kind of character string or binary string?
*
* @param typeCode a JDBC type code from {@link Types}
*/
public static boolean isStringType(int typeCode) {
@ -715,6 +729,7 @@ public class SqlTypes {
/**
* Does the given JDBC type code represent some sort of
* character string type?
*
* @param typeCode a JDBC type code from {@link Types}
*/
public static boolean isCharacterOrClobType(int typeCode) {
@ -736,6 +751,7 @@ public class SqlTypes {
/**
* Does the given JDBC type code represent some sort of
* character string type?
*
* @param typeCode a JDBC type code from {@link Types}
*/
public static boolean isCharacterType(int typeCode) {
@ -755,6 +771,7 @@ public class SqlTypes {
/**
* Does the given JDBC type code represent some sort of
* variable-length character string type?
*
* @param typeCode a JDBC type code from {@link Types}
*/
public static boolean isVarcharType(int typeCode) {

View File

@ -551,7 +551,6 @@ public final class StandardBasicTypes {
);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Binary mappings
@ -746,12 +745,36 @@ public final class StandardBasicTypes {
/**
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
* specifically for embedding vectors like provided by the PostgreSQL extension pgvector.
* specifically for embedding vectors like provided by the PostgreSQL extension pgvector and Oracle 23ai.
*/
public static final BasicTypeReference<float[]> VECTOR = new BasicTypeReference<>(
"vector", float[].class, SqlTypes.VECTOR
);
/**
* The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8},
* specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai.
*/
public static final BasicTypeReference<byte[]> VECTOR_INT8 = new BasicTypeReference<>(
"byte_vector", byte[].class, SqlTypes.VECTOR_INT8
);
/**
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_FLOAT32 VECTOR_FLOAT32},
* specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai.
*/
public static final BasicTypeReference<float[]> VECTOR_FLOAT32 = new BasicTypeReference<>(
"float_vector", float[].class, SqlTypes.VECTOR_FLOAT32
);
/**
* The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_FLOAT64 VECTOR_FLOAT64},
* specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai.
*/
public static final BasicTypeReference<double[]> VECTOR_FLOAT64 = new BasicTypeReference<>(
"double_vector", double[].class, SqlTypes.VECTOR_FLOAT64
);
public static void prime(TypeConfiguration typeConfiguration) {
BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
@ -1252,6 +1275,27 @@ public final class StandardBasicTypes {
"vector"
);
handle(
VECTOR_FLOAT32,
null,
basicTypeRegistry,
"float_vector"
);
handle(
VECTOR_FLOAT64,
null,
basicTypeRegistry,
"double_vector"
);
handle(
VECTOR_INT8,
null,
basicTypeRegistry,
"byte_vector"
);
// Specialized version handlers

View File

@ -0,0 +1,155 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.hibernate.dialect.Dialect;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.ValueBinder;
import org.hibernate.type.descriptor.ValueExtractor;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
import org.hibernate.type.descriptor.java.ByteJavaType;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.BasicBinder;
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.internal.JdbcLiteralFormatterArray;
/**
 * Common base of the Oracle vector JDBC type mappings.
 * <p>
 * Values are bound and extracted in one of two ways, selected by the
 * {@code isVectorSupported} flag given at construction: either through
 * {@code setObject}/{@code getObject} using the native type code and Java
 * type supplied by the subclass, or through the string form produced and
 * parsed by the subclass.
 *
 * @author Hassan AL Meftah
 */
public abstract class AbstractOracleVectorJdbcType extends ArrayJdbcType {

    final boolean isVectorSupported;

    public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
        super( elementJdbcType );
        this.isVectorSupported = isVectorSupported;
    }

    public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);

    @Override
    public int getDefaultSqlTypeCode() {
        return SqlTypes.VECTOR;
    }

    @Override
    public String toString() {
        return "OracleVectorTypeDescriptor";
    }

    /**
     * Builds an array literal formatter from the element formatter of the
     * element JDBC type.
     *
     * @throws IllegalArgumentException if the given Java type is neither the
     * primitive byte array type nor a {@link BasicPluralJavaType}
     */
    @Override
    public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
        final JavaType<T> elementJavaType;
        if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) {
            // The primitive byte array type is handled up front with an explicit element type;
            // the original note says this avoids a conflict with the VARBINARY mapping
            //noinspection unchecked
            elementJavaType = (JavaType<T>) ByteJavaType.INSTANCE;
        }
        else if ( !( javaTypeDescriptor instanceof BasicPluralJavaType ) ) {
            throw new IllegalArgumentException( "not a BasicPluralJavaType" );
        }
        else {
            //noinspection unchecked
            elementJavaType = ( (BasicPluralJavaType<T>) javaTypeDescriptor ).getElementJavaType();
        }
        return new JdbcLiteralFormatterArray<>(
                javaTypeDescriptor,
                getElementJdbcType().getJdbcLiteralFormatter( elementJavaType )
        );
    }

    @Override
    public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
        return new BasicBinder<>( javaTypeDescriptor, this ) {
            @Override
            protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options)
                    throws SQLException {
                final AbstractOracleVectorJdbcType vectorType = (AbstractOracleVectorJdbcType) getJdbcType();
                if ( !isVectorSupported ) {
                    // String path: send the textual form of the vector
                    st.setString( index, vectorType.getStringVector( value, getJavaType(), options ) );
                    return;
                }
                st.setObject( index, value, vectorType.getNativeTypeCode() );
            }

            @Override
            protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
                    throws SQLException {
                final AbstractOracleVectorJdbcType vectorType = (AbstractOracleVectorJdbcType) getJdbcType();
                if ( !isVectorSupported ) {
                    st.setString( name, vectorType.getStringVector( value, getJavaType(), options ) );
                    return;
                }
                st.setObject( name, value, vectorType.getNativeTypeCode() );
            }
        };
    }

    @Override
    public <X> ValueExtractor<X> getExtractor(final JavaType<X> javaTypeDescriptor) {
        return new BasicExtractor<>( javaTypeDescriptor, this ) {
            @Override
            protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
                final AbstractOracleVectorJdbcType vectorType = (AbstractOracleVectorJdbcType) getJdbcType();
                if ( !isVectorSupported ) {
                    // String path: read the textual form and let the subclass parse it
                    return getJavaType().wrap( vectorType.getVectorArray( rs.getString( paramIndex ) ), options );
                }
                return getJavaType().wrap( rs.getObject( paramIndex, vectorType.getNativeJavaType() ), options );
            }

            @Override
            protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
                final AbstractOracleVectorJdbcType vectorType = (AbstractOracleVectorJdbcType) getJdbcType();
                if ( !isVectorSupported ) {
                    return getJavaType().wrap( vectorType.getVectorArray( statement.getString( index ) ), options );
                }
                return getJavaType().wrap( statement.getObject( index, vectorType.getNativeJavaType() ), options );
            }

            @Override
            protected X doExtract(CallableStatement statement, String name, WrapperOptions options)
                    throws SQLException {
                final AbstractOracleVectorJdbcType vectorType = (AbstractOracleVectorJdbcType) getJdbcType();
                if ( !isVectorSupported ) {
                    return getJavaType().wrap( vectorType.getVectorArray( statement.getString( name ) ), options );
                }
                return getJavaType().wrap( statement.getObject( name, vectorType.getNativeJavaType() ), options );
            }
        };
    }

    /** Parses the string form of a vector into the array type of the subclass. */
    protected abstract <T> T getVectorArray(String string);

    /** Renders the given vector value as the string that is bound when the string path is used. */
    protected abstract <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options);

    /** The Java class requested from {@code getObject} when the native path is used. */
    protected abstract Class<?> getNativeJavaType();

    /** The type code passed to {@code setObject} when the native path is used. */
    protected abstract int getNativeTypeCode();
}

View File

@ -0,0 +1,93 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.Arrays;
import java.util.BitSet;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleTypes;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
/**
 * Specialized type mapping for single-byte integer vector {@link SqlTypes#VECTOR_INT8} SQL data type for Oracle.
 *
 * @author Hassan AL Meftah
 */
public class OracleByteVectorJdbcType extends AbstractOracleVectorJdbcType {

    private static final byte[] EMPTY = new byte[0];

    public OracleByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
        super( elementJdbcType, isVectorSupported );
    }

    @Override
    public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
        appender.append( "to_vector(" );
        appender.append( writeExpression );
        appender.append( ", *, INT8)" );
    }

    @Override
    public String getFriendlyName() {
        return "VECTOR_INT8";
    }

    @Override
    public int getDefaultSqlTypeCode() {
        return SqlTypes.VECTOR_INT8;
    }

    /**
     * Parses a bracketed, comma separated list such as {@code [1,2,3]} into a {@code byte[]}.
     * <p>
     * Whitespace around the elements is tolerated, so the {@code [1, 2, 3]} form
     * written by {@link #getStringVector} is accepted as well.
     *
     * @param string the textual vector, may be {@code null}
     * @return {@code null} for {@code null} input, an empty array for a two character
     * string (i.e. {@code []}), otherwise the parsed elements
     * @throws NumberFormatException if an element is not a valid {@code byte}
     */
    @Override
    protected byte[] getVectorArray(String string) {
        if ( string == null ) {
            return null;
        }
        if ( string.length() == 2 ) {
            return EMPTY;
        }
        // Count the separators first so the result can be allocated with its exact size
        int size = 1;
        for ( int i = 1; i < string.length(); i++ ) {
            if ( string.charAt( i ) == ',' ) {
                size++;
            }
        }
        final byte[] result = new byte[size];
        // Index 0 holds the opening bracket, so the first element starts at 1
        int elementStartIndex = 1;
        int commaIndex;
        int index = 0;
        while ( ( commaIndex = string.indexOf( ',', elementStartIndex ) ) != -1 ) {
            result[index++] = parseElement( string, elementStartIndex, commaIndex );
            elementStartIndex = commaIndex + 1;
        }
        // The last element ends right before the closing bracket
        result[index] = parseElement( string, elementStartIndex, string.length() - 1 );
        return result;
    }

    /** Parses the element between the two offsets, ignoring surrounding whitespace. */
    private static byte parseElement(String string, int start, int end) {
        return Byte.parseByte( string.substring( start, end ).trim() );
    }

    @Override
    protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
        return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) );
    }

    @Override
    protected Class<?> getNativeJavaType() {
        return byte[].class;
    }

    @Override
    protected int getNativeTypeCode() {
        return OracleTypes.VECTOR_INT8;
    }
}

View File

@ -0,0 +1,93 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.Arrays;
import java.util.BitSet;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleTypes;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
/**
 * Oracle mapping of the double-precision floating-point vector type
 * ({@link SqlTypes#VECTOR_FLOAT64}) to {@code double[]}.
 *
 * @author Hassan AL Meftah
 */
public class OracleDoubleVectorJdbcType extends AbstractOracleVectorJdbcType {

    private static final double[] EMPTY = new double[0];

    public OracleDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
        super( elementJdbcType, isVectorSupported );
    }

    @Override
    public String getFriendlyName() {
        return "VECTOR_FLOAT64";
    }

    @Override
    public int getDefaultSqlTypeCode() {
        return SqlTypes.VECTOR_FLOAT64;
    }

    @Override
    protected Class<?> getNativeJavaType() {
        return double[].class;
    }

    @Override
    protected int getNativeTypeCode() {
        return OracleTypes.VECTOR_FLOAT64;
    }

    /** Wraps the write expression in a {@code to_vector} call with the {@code FLOAT64} format. */
    @Override
    public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
        appender.append( "to_vector(" );
        appender.append( writeExpression );
        appender.append( ", *, FLOAT64)" );
    }

    /**
     * Parses a bracketed, comma separated list into a {@code double[]}.
     *
     * @param string the textual vector, may be {@code null}
     * @return {@code null} for {@code null} input, the shared empty array for a
     * two character string, otherwise the parsed elements
     */
    @Override
    protected double[] getVectorArray(String string) {
        if ( string == null ) {
            return null;
        }
        if ( string.length() == 2 ) {
            return EMPTY;
        }
        // One element more than there are separators after the opening bracket
        int elementCount = 1;
        for ( int comma = string.indexOf( ',', 1 ); comma != -1; comma = string.indexOf( ',', comma + 1 ) ) {
            elementCount++;
        }
        final double[] values = new double[elementCount];
        int slot = 0;
        int elementStart = 1;
        for ( int comma = string.indexOf( ',', elementStart );
                comma != -1;
                comma = string.indexOf( ',', elementStart ) ) {
            values[slot++] = Double.parseDouble( string.substring( elementStart, comma ) );
            elementStart = comma + 1;
        }
        // Trailing element: everything up to, but excluding, the last character
        values[slot] = Double.parseDouble( string.substring( elementStart, string.length() - 1 ) );
        return values;
    }

    @Override
    protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
        final double[] unwrapped = javaTypeDescriptor.unwrap( vector, double[].class, options );
        return Arrays.toString( unwrapped );
    }
}

View File

@ -0,0 +1,93 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.Arrays;
import java.util.BitSet;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleTypes;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
/**
 * Oracle mapping of the single-precision floating-point vector type
 * ({@link SqlTypes#VECTOR_FLOAT32}) to {@code float[]}.
 *
 * @author Hassan AL Meftah
 */
public class OracleFloatVectorJdbcType extends AbstractOracleVectorJdbcType {

    private static final float[] EMPTY = new float[0];

    public OracleFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
        super( elementJdbcType, isVectorSupported );
    }

    @Override
    public String getFriendlyName() {
        return "VECTOR_FLOAT32";
    }

    @Override
    public int getDefaultSqlTypeCode() {
        return SqlTypes.VECTOR_FLOAT32;
    }

    @Override
    protected Class<?> getNativeJavaType() {
        return float[].class;
    }

    @Override
    protected int getNativeTypeCode() {
        return OracleTypes.VECTOR_FLOAT32;
    }

    /** Wraps the write expression in a {@code to_vector} call with the {@code FLOAT32} format. */
    @Override
    public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
        appender.append( "to_vector(" );
        appender.append( writeExpression );
        appender.append( ", *, FLOAT32)" );
    }

    /**
     * Parses a bracketed, comma separated list into a {@code float[]}.
     *
     * @param string the textual vector, may be {@code null}
     * @return {@code null} for {@code null} input, the shared empty array for a
     * two character string, otherwise the parsed elements
     */
    @Override
    protected float[] getVectorArray(String string) {
        if ( string == null ) {
            return null;
        }
        if ( string.length() == 2 ) {
            return EMPTY;
        }
        // One element more than there are separators after the opening bracket
        int elementCount = 1;
        for ( int comma = string.indexOf( ',', 1 ); comma != -1; comma = string.indexOf( ',', comma + 1 ) ) {
            elementCount++;
        }
        final float[] values = new float[elementCount];
        int slot = 0;
        int elementStart = 1;
        for ( int comma = string.indexOf( ',', elementStart );
                comma != -1;
                comma = string.indexOf( ',', elementStart ) ) {
            values[slot++] = Float.parseFloat( string.substring( elementStart, comma ) );
            elementStart = comma + 1;
        }
        // Trailing element: everything up to, but excluding, the last character
        values[slot] = Float.parseFloat( string.substring( elementStart, string.length() - 1 ) );
        return values;
    }

    @Override
    protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
        final float[] unwrapped = javaTypeDescriptor.unwrap( vector, float[].class, options );
        return Arrays.toString( unwrapped );
    }
}

View File

@ -0,0 +1,107 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * Contributes the vector functions when the dialect is an {@link OracleDialect}.
 * <p>
 * All distance functions are rendered through {@code vector_distance} with a
 * metric keyword; {@code vector_dims} and {@code vector_norm} are registered
 * under their own names.
 */
public class OracleVectorFunctionContributor implements FunctionContributor {

    @Override
    public void contributeFunctions(FunctionContributions functionContributions) {
        final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
        final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
        final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
        final Dialect dialect = functionContributions.getDialect();
        if ( !( dialect instanceof OracleDialect ) ) {
            // Nothing to contribute for other dialects
            return;
        }
        final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
        final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );

        registerDistanceFunction( functionRegistry, "cosine_distance", "vector_distance(?1, ?2, COSINE)", doubleType );
        registerDistanceFunction( functionRegistry, "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)", doubleType );
        functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
        registerDistanceFunction( functionRegistry, "l1_distance", "vector_distance(?1, ?2, MANHATTAN)", doubleType );
        functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
        registerDistanceFunction( functionRegistry, "negative_inner_product", "vector_distance(?1, ?2, DOT)", doubleType );
        // Same expression as negative_inner_product, multiplied by -1
        registerDistanceFunction( functionRegistry, "inner_product", "vector_distance(?1, ?2, DOT)*-1", doubleType );
        registerDistanceFunction( functionRegistry, "hamming_distance", "vector_distance(?1, ?2, HAMMING)", doubleType );

        registerUnaryVectorFunction( functionRegistry, "vector_dims", integerType );
        registerUnaryVectorFunction( functionRegistry, "vector_norm", doubleType );
    }

    /** Registers a function taking exactly two vector arguments that is rendered through the given SQL pattern. */
    private static void registerDistanceFunction(
            SqmFunctionRegistry functionRegistry,
            String name,
            String pattern,
            BasicType<?> returnType) {
        functionRegistry.patternDescriptorBuilder( name, pattern )
                .setArgumentsValidator( StandardArgumentsValidators.composite(
                        StandardArgumentsValidators.exactly( 2 ),
                        VectorArgumentValidator.INSTANCE
                ) )
                .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
                .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
                .register();
    }

    /** Registers a function taking exactly one vector argument that is rendered under its own name. */
    private static void registerUnaryVectorFunction(
            SqmFunctionRegistry functionRegistry,
            String name,
            BasicType<?> returnType) {
        functionRegistry.namedDescriptorBuilder( name )
                .setArgumentsValidator( StandardArgumentsValidators.composite(
                        StandardArgumentsValidators.exactly( 1 ),
                        VectorArgumentValidator.INSTANCE
                ) )
                .setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
                .setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
                .register();
    }

    @Override
    public int ordinal() {
        return 200;
    }
}

View File

@ -0,0 +1,51 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleTypes;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.jdbc.JdbcType;
/**
 * Oracle mapping of the generic vector type ({@link SqlTypes#VECTOR}).
 * <p>
 * Parsing, string rendering and the native Java type are inherited from
 * {@link OracleFloatVectorJdbcType}; this class only swaps the type codes,
 * the friendly name and the write expression, which uses an asterisk
 * ({@code *}) for the element format.
 *
 * @author Hassan AL Meftah
 */
public class OracleVectorJdbcType extends OracleFloatVectorJdbcType {

    public OracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
        super( elementJdbcType, isVectorSupported );
    }

    @Override
    public int getDefaultSqlTypeCode() {
        return SqlTypes.VECTOR;
    }

    @Override
    protected int getNativeTypeCode() {
        return OracleTypes.VECTOR;
    }

    @Override
    public String getFriendlyName() {
        return "VECTOR";
    }

    /** Wraps the write expression in a {@code to_vector} call with {@code *} for both trailing arguments. */
    @Override
    public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
        appender.append( "to_vector(" );
        appender.append( writeExpression );
        appender.append( ", *, *)" );
    }
}

View File

@ -0,0 +1,167 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.sql.SQLException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.boot.model.TypeContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.internal.util.StringHelper;
import org.hibernate.service.ServiceRegistry;
import org.hibernate.type.BasicArrayType;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeReference;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
import org.hibernate.type.spi.TypeConfiguration;
public class OracleVectorTypeContributor implements TypeContributor {
@Override
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class )
.getDialect();
if ( dialect instanceof OracleDialect && dialect.getVersion().isSameOrAfter( 23 ) ) {
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
final boolean isVectorSupported = isVectorSupportedByDriver( (OracleDialect) dialect );
// Register generic vector type
final OracleVectorJdbcType genericVectorJdbcType = new OracleVectorJdbcType(
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
isVectorSupported
);
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType );
final JdbcType floatVectorJdbcType = new OracleFloatVectorJdbcType(
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
isVectorSupported
);
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType );
final JdbcType doubleVectorJdbcType = new OracleDoubleVectorJdbcType(
jdbcTypeRegistry.getDescriptor( SqlTypes.DOUBLE ),
isVectorSupported
);
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT64, doubleVectorJdbcType );
final JdbcType byteVectorJdbcType = new OracleByteVectorJdbcType(
jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ),
isVectorSupported
);
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType );
// Resolving basic types after jdbc types are registered.
basicTypeRegistry.register(
new BasicArrayType<>(
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
genericVectorJdbcType,
javaTypeRegistry.getDescriptor( float[].class )
),
StandardBasicTypes.VECTOR.getName()
);
basicTypeRegistry.register(
new BasicArrayType<>(
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
floatVectorJdbcType,
javaTypeRegistry.getDescriptor( float[].class )
),
StandardBasicTypes.VECTOR_FLOAT32.getName()
);
basicTypeRegistry.register(
new BasicArrayType<>(
basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ),
doubleVectorJdbcType,
javaTypeRegistry.getDescriptor( double[].class )
),
StandardBasicTypes.VECTOR_FLOAT64.getName()
);
basicTypeRegistry.register(
new BasicArrayType<>(
basicTypeRegistry.resolve( StandardBasicTypes.BYTE ),
byteVectorJdbcType,
javaTypeRegistry.getDescriptor( byte[].class )
),
StandardBasicTypes.VECTOR_INT8.getName()
);
typeConfiguration.getDdlTypeRegistry().addDescriptor(
new DdlTypeImpl( SqlTypes.VECTOR, "vector($l, *)", "vector", dialect ) {
@Override
public String getTypeName(Size size) {
return OracleVectorTypeContributor.replace(
"vector($l, *)",
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
);
}
}
);
typeConfiguration.getDdlTypeRegistry().addDescriptor(
new DdlTypeImpl( SqlTypes.VECTOR_INT8, "vector($l, INT8)", "vector", dialect ) {
@Override
public String getTypeName(Size size) {
return OracleVectorTypeContributor.replace(
"vector($l, INT8)",
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
);
}
}
);
typeConfiguration.getDdlTypeRegistry().addDescriptor(
new DdlTypeImpl( SqlTypes.VECTOR_FLOAT32, "vector($l, FLOAT32)", "vector", dialect ) {
@Override
public String getTypeName(Size size) {
return OracleVectorTypeContributor.replace(
"vector($l, FLOAT32)",
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
);
}
}
);
typeConfiguration.getDdlTypeRegistry().addDescriptor(
new DdlTypeImpl( SqlTypes.VECTOR_FLOAT64, "vector($l, FLOAT64)", "vector", dialect ) {
@Override
public String getTypeName(Size size) {
return OracleVectorTypeContributor.replace(
"vector($l, FLOAT64)",
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
);
}
}
);
}
}
/**
 * Fills the {@code $l} placeholder of a vector DDL type pattern with the vector dimension.
 * When no dimension is given, {@code *} is used so the column has an undefined length.
 *
 * @param type the DDL type pattern containing the {@code $l} placeholder
 * @param size the vector dimension, or {@code null} if undefined
 * @return the pattern with the placeholder substituted via {@code StringHelper.replaceOnce}
 */
private static String replace(String type, Long size) {
	final String dimension = size == null ? "*" : size.toString();
	return StringHelper.replaceOnce( type, "$l", dimension );
}
/**
 * Checks whether the JDBC driver version reported by the dialect is at least 23.4.
 *
 * @param dialect the Oracle dialect exposing the driver major and minor version
 * @return {@code true} for driver 23.4 or any later version
 */
private boolean isVectorSupportedByDriver(OracleDialect dialect) {
	final int major = dialect.getDriverMajorVersion();
	if ( major != 23 ) {
		// Any later major version qualifies, any earlier one does not
		return major > 23;
	}
	// Within the 23 line the minor version decides
	return dialect.getDriverMinorVersion() >= 4;
}
}

View File

@ -14,6 +14,7 @@ import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.BasicType;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.spi.TypeConfiguration;
@ -24,6 +25,13 @@ public class VectorArgumentValidator implements ArgumentsValidator {
// Shared instance of this validator
public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator();
// SQL type codes of the generic and the element-typed vector types.
// NOTE(review): the visible isVectorType() switches over the same four codes and does not
// read this array - confirm whether it is used elsewhere in the class.
private static final int[] availableVectorCodes = {
SqlTypes.VECTOR,
SqlTypes.VECTOR_INT8,
SqlTypes.VECTOR_FLOAT32,
SqlTypes.VECTOR_FLOAT64
};
@Override
public void validate(
List<? extends SqmTypedNode<?>> arguments,
@ -46,7 +54,18 @@ public class VectorArgumentValidator implements ArgumentsValidator {
}
/**
 * Tells whether the given expressible is a plural basic type whose JDBC type reports one of the
 * vector SQL type codes.
 *
 * @param vectorType the type of the argument being validated
 * @return {@code true} if the default SQL type code is {@code VECTOR}, {@code VECTOR_INT8},
 * {@code VECTOR_FLOAT32} or {@code VECTOR_FLOAT64}
 */
private static boolean isVectorType(SqmExpressible<?> vectorType) {
	// The former single-code check returned before the switch, leaving it unreachable
	// and accepting only SqlTypes.VECTOR; it has been dropped in favour of the switch.
	if ( !( vectorType instanceof BasicPluralType<?, ?> ) ) {
		return false;
	}
	switch ( ( (BasicType<?>) vectorType ).getJdbcType().getDefaultSqlTypeCode() ) {
		case SqlTypes.VECTOR:
		case SqlTypes.VECTOR_INT8:
		case SqlTypes.VECTOR_FLOAT32:
		case SqlTypes.VECTOR_FLOAT64:
			return true;
		default:
			return false;
	}
}
}

View File

@ -1 +1,2 @@
org.hibernate.vector.PGVectorFunctionContributor
org.hibernate.vector.OracleVectorFunctionContributor

View File

@ -1 +1,2 @@
org.hibernate.vector.PGVectorTypeContributor
org.hibernate.vector.OracleVectorTypeContributor

View File

@ -0,0 +1,280 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.List;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.type.SqlTypes;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Round-trips a {@code byte[]} attribute mapped with {@link SqlTypes#VECTOR_INT8} and checks the
 * HQL vector functions against values recomputed in Java by the helpers at the bottom of the class.
 *
 * @author Hassan AL Meftah
 */
@DomainModel(annotatedClasses = OracleByteVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
public class OracleByteVectorTest {

	// Fixture vectors persisted under ids 1 and 2 before every test
	private static final byte[] V1 = new byte[]{ 1, 2, 3 };
	private static final byte[] V2 = new byte[]{ 4, 5, 6 };

	@BeforeEach
	public void prepareData(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.persist( new VectorEntity( 1L, V1 ) );
			em.persist( new VectorEntity( 2L, V2 ) );
		} );
	}

	@AfterEach
	public void cleanup(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
		} );
	}

	@Test
	public void testRead(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			VectorEntity tableRecord;
			tableRecord = em.find( VectorEntity.class, 1L );
			assertArrayEquals( new byte[]{ 1, 2, 3 }, tableRecord.getTheVector() );
			tableRecord = em.find( VectorEntity.class, 2L );
			assertArrayEquals( new byte[]{ 4, 5, 6 }, tableRecord.getTheVector() );
		} );
	}

	@Test
	public void testCosineDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::cosine-distance-example[]
			final byte[] vector = new byte[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector, byte[].class )
					.getResultList();
			//end::cosine-distance-example[]
			assertDistances( results, cosineDistance( V1, vector ), cosineDistance( V2, vector ), 0.0000001D );
		} );
	}

	@Test
	public void testEuclideanDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::euclidean-distance-example[]
			final byte[] vector = new byte[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::euclidean-distance-example[]
			assertDistances( results, euclideanDistance( V1, vector ), euclideanDistance( V2, vector ), 0.000001D );
		} );
	}

	@Test
	public void testTaxicabDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::taxicab-distance-example[]
			final byte[] vector = new byte[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::taxicab-distance-example[]
			assertDistances( results, taxicabDistance( V1, vector ), taxicabDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testInnerProduct(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::inner-product-example[]
			final byte[] vector = new byte[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::inner-product-example[]
			assertDistances( results, innerProduct( V1, vector ), innerProduct( V2, vector ), 0D );
			// Third column carries the negated inner product
			assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
			assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
		} );
	}

	@Test
	public void testHammingDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::hamming-distance-example[]
			final byte[] vector = new byte[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::hamming-distance-example[]
			assertDistances( results, hammingDistance( V1, vector ), hammingDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testVectorDims(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-dims-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-dims-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( V1.length, results.get( 0 ).get( 1 ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( V2.length, results.get( 1 ).get( 1 ) );
		} );
	}

	@Test
	public void testVectorNorm(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-norm-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-norm-example[]
			assertDistances( results, euclideanNorm( V1 ), euclideanNorm( V2 ), 0D );
		} );
	}

	/**
	 * Asserts that the result holds exactly the rows for ids 1 and 2, in that order, and that the
	 * second column of each row matches the expected value within {@code delta}.
	 */
	private static void assertDistances(List<Tuple> results, double expectedFirst, double expectedSecond, double delta) {
		assertEquals( 2, results.size() );
		assertEquals( 1L, results.get( 0 ).get( 0 ) );
		assertEquals( expectedFirst, results.get( 0 ).get( 1, double.class ), delta );
		assertEquals( 2L, results.get( 1 ).get( 0 ) );
		assertEquals( expectedSecond, results.get( 1 ).get( 1, double.class ), delta );
	}

	/** One minus the inner product divided by the product of the Euclidean norms. */
	private static double cosineDistance(byte[] f1, byte[] f2) {
		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
	}

	/** Square root of the summed squared component differences. */
	private static double euclideanDistance(byte[] f1, byte[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.pow( (double) f1[i] - f2[i], 2 );
		}
		return Math.sqrt( result );
	}

	/**
	 * Sum of the absolute component differences (L1 distance). The difference of the two L1 norms
	 * is not a distance: it can be negative, or zero for distinct vectors.
	 */
	private static double taxicabDistance(byte[] f1, byte[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.abs( (double) f1[i] - f2[i] );
		}
		return result;
	}

	/** Sum of the component-wise products. */
	private static double innerProduct(byte[] f1, byte[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += ( (double) f1[i] ) * ( (double) f2[i] );
		}
		return result;
	}

	/** Number of positions at which the two vectors differ. */
	public static double hammingDistance(byte[] f1, byte[] f2) {
		assert f1.length == f2.length;
		int distance = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			if ( f1[i] != f2[i] ) {
				distance++;
			}
		}
		return distance;
	}

	/** Square root of the summed squared components. */
	private static double euclideanNorm(byte[] f) {
		double result = 0;
		for ( double v : f ) {
			result += Math.pow( v, 2 );
		}
		return Math.sqrt( result );
	}

	/** Entity with a three-element {@code byte[]} attribute mapped as {@code VECTOR_INT8}. */
	@Entity( name = "VectorEntity" )
	public static class VectorEntity {

		@Id
		private Long id;

		//tag::usage-example[]
		@Column( name = "the_vector" )
		@JdbcTypeCode(SqlTypes.VECTOR_INT8)
		@Array(length = 3)
		private byte[] theVector;
		//end::usage-example[]

		public VectorEntity() {
		}

		public VectorEntity(Long id, byte[] theVector) {
			this.id = id;
			this.theVector = theVector;
		}

		public Long getId() {
			return id;
		}

		public void setId(Long id) {
			this.id = id;
		}

		public byte[] getTheVector() {
			return theVector;
		}

		public void setTheVector(byte[] theVector) {
			this.theVector = theVector;
		}
	}
}

View File

@ -0,0 +1,277 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.List;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.type.SqlTypes;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Round-trips a {@code double[]} attribute mapped with {@link SqlTypes#VECTOR_FLOAT64} and checks
 * the HQL vector functions against values recomputed in Java by the helpers at the bottom of the class.
 *
 * @author Hassan AL Meftah
 */
@DomainModel(annotatedClasses = OracleDoubleVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
public class OracleDoubleVectorTest {

	// Fixture vectors persisted under ids 1 and 2 before every test
	private static final double[] V1 = new double[]{ 1, 2, 3 };
	private static final double[] V2 = new double[]{ 4, 5, 6 };

	@BeforeEach
	public void prepareData(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.persist( new VectorEntity( 1L, V1 ) );
			em.persist( new VectorEntity( 2L, V2 ) );
		} );
	}

	@AfterEach
	public void cleanup(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
		} );
	}

	@Test
	public void testRead(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			VectorEntity tableRecord;
			tableRecord = em.find( VectorEntity.class, 1L );
			assertArrayEquals( new double[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );
			tableRecord = em.find( VectorEntity.class, 2L );
			assertArrayEquals( new double[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
		} );
	}

	@Test
	public void testCosineDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::cosine-distance-example[]
			final double[] vector = new double[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::cosine-distance-example[]
			assertDistances( results, cosineDistance( V1, vector ), cosineDistance( V2, vector ), 0.0000001D );
		} );
	}

	@Test
	public void testEuclideanDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::euclidean-distance-example[]
			final double[] vector = new double[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::euclidean-distance-example[]
			assertDistances( results, euclideanDistance( V1, vector ), euclideanDistance( V2, vector ), 0.00002D );
		} );
	}

	@Test
	public void testTaxicabDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::taxicab-distance-example[]
			final double[] vector = new double[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::taxicab-distance-example[]
			assertDistances( results, taxicabDistance( V1, vector ), taxicabDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testInnerProduct(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::inner-product-example[]
			final double[] vector = new double[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::inner-product-example[]
			assertDistances( results, innerProduct( V1, vector ), innerProduct( V2, vector ), 0D );
			// Third column carries the negated inner product
			assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
			assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
		} );
	}

	@Test
	public void testHammingDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::hamming-distance-example[]
			final double[] vector = new double[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::hamming-distance-example[]
			// The shared helper also verifies the id of the second row, like the sibling tests
			assertDistances( results, hammingDistance( V1, vector ), hammingDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testVectorDims(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-dims-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-dims-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( V1.length, results.get( 0 ).get( 1 ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( V2.length, results.get( 1 ).get( 1 ) );
		} );
	}

	@Test
	public void testVectorNorm(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-norm-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-norm-example[]
			assertDistances( results, euclideanNorm( V1 ), euclideanNorm( V2 ), 0D );
		} );
	}

	/**
	 * Asserts that the result holds exactly the rows for ids 1 and 2, in that order, and that the
	 * second column of each row matches the expected value within {@code delta}.
	 */
	private static void assertDistances(List<Tuple> results, double expectedFirst, double expectedSecond, double delta) {
		assertEquals( 2, results.size() );
		assertEquals( 1L, results.get( 0 ).get( 0 ) );
		assertEquals( expectedFirst, results.get( 0 ).get( 1, double.class ), delta );
		assertEquals( 2L, results.get( 1 ).get( 0 ) );
		assertEquals( expectedSecond, results.get( 1 ).get( 1, double.class ), delta );
	}

	/** One minus the inner product divided by the product of the Euclidean norms. */
	private static double cosineDistance(double[] f1, double[] f2) {
		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
	}

	/** Square root of the summed squared component differences. */
	private static double euclideanDistance(double[] f1, double[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.pow( f1[i] - f2[i], 2 );
		}
		return Math.sqrt( result );
	}

	/**
	 * Sum of the absolute component differences (L1 distance). The difference of the two L1 norms
	 * is not a distance: it can be negative, or zero for distinct vectors.
	 */
	private static double taxicabDistance(double[] f1, double[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.abs( f1[i] - f2[i] );
		}
		return result;
	}

	/** Number of positions at which the two vectors differ. */
	public static double hammingDistance(double[] f1, double[] f2) {
		assert f1.length == f2.length;
		int distance = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			if ( f1[i] != f2[i] ) {
				distance++;
			}
		}
		return distance;
	}

	/** Sum of the component-wise products. */
	private static double innerProduct(double[] f1, double[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += f1[i] * f2[i];
		}
		return result;
	}

	/** Square root of the summed squared components. */
	private static double euclideanNorm(double[] f) {
		double result = 0;
		for ( double v : f ) {
			result += Math.pow( v, 2 );
		}
		return Math.sqrt( result );
	}

	/** Entity with a three-element {@code double[]} attribute mapped as {@code VECTOR_FLOAT64}. */
	@Entity( name = "VectorEntity" )
	public static class VectorEntity {

		@Id
		private Long id;

		//tag::usage-example[]
		@Column( name = "the_vector" )
		@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)
		@Array(length = 3)
		private double[] theVector;
		//end::usage-example[]

		public VectorEntity() {
		}

		public VectorEntity(Long id, double[] theVector) {
			this.id = id;
			this.theVector = theVector;
		}

		public Long getId() {
			return id;
		}

		public void setId(Long id) {
			this.id = id;
		}

		public double[] getTheVector() {
			return theVector;
		}

		public void setTheVector(double[] theVector) {
			this.theVector = theVector;
		}
	}
}

View File

@ -0,0 +1,301 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.List;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.type.SqlTypes;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Round-trips a {@code float[]} attribute mapped with {@link SqlTypes#VECTOR_FLOAT32} and checks
 * the HQL vector functions against values recomputed in Java by the helpers at the bottom of the class.
 *
 * @author Hassan AL Meftah
 */
@DomainModel(annotatedClasses = OracleFloatVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
public class OracleFloatVectorTest {

	// Fixture vectors persisted under ids 1 and 2 before every test
	private static final float[] V1 = new float[] { 1, 2, 3 };
	private static final float[] V2 = new float[] { 4, 5, 6 };

	@BeforeEach
	public void prepareData(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.persist( new VectorEntity( 1L, V1 ) );
			em.persist( new VectorEntity( 2L, V2 ) );
		} );
	}

	@AfterEach
	public void cleanup(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
		} );
	}

	@Test
	public void testRead(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			VectorEntity tableRecord;
			tableRecord = em.find( VectorEntity.class, 1L );
			assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 );
			tableRecord = em.find( VectorEntity.class, 2L );
			assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 );
		} );
	}

	@Test
	public void testCosineDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::cosine-distance-example[]
			final float[] vector = new float[] { 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id",
							Tuple.class
					)
					.setParameter( "vec", vector )
					.getResultList();
			//end::cosine-distance-example[]
			assertDistances( results, cosineDistance( V1, vector ), cosineDistance( V2, vector ), 0.0000001D );
		} );
	}

	@Test
	public void testEuclideanDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::euclidean-distance-example[]
			final float[] vector = new float[] { 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id",
							Tuple.class
					)
					.setParameter( "vec", vector )
					.getResultList();
			//end::euclidean-distance-example[]
			assertDistances( results, euclideanDistance( V1, vector ), euclideanDistance( V2, vector ), 0.000001D );
		} );
	}

	@Test
	public void testTaxicabDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::taxicab-distance-example[]
			final float[] vector = new float[] { 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id",
							Tuple.class
					)
					.setParameter( "vec", vector )
					.getResultList();
			//end::taxicab-distance-example[]
			assertDistances( results, taxicabDistance( V1, vector ), taxicabDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testInnerProduct(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::inner-product-example[]
			final float[] vector = new float[] { 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id",
							Tuple.class
					)
					.setParameter( "vec", vector )
					.getResultList();
			//end::inner-product-example[]
			assertDistances( results, innerProduct( V1, vector ), innerProduct( V2, vector ), 0D );
			// Third column carries the negated inner product
			assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
			assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
		} );
	}

	@Test
	public void testHammingDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::hamming-distance-example[]
			final float[] vector = new float[] { 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id",
							Tuple.class
					)
					.setParameter( "vec", vector )
					.getResultList();
			//end::hamming-distance-example[]
			assertDistances( results, hammingDistance( V1, vector ), hammingDistance( V2, vector ), 0D );
		} );
	}

	@Test
	public void testVectorDims(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-dims-example[]
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id",
							Tuple.class
					)
					.getResultList();
			//end::vector-dims-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( V1.length, results.get( 0 ).get( 1 ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( V2.length, results.get( 1 ).get( 1 ) );
		} );
	}

	@Test
	public void testVectorNorm(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-norm-example[]
			final List<Tuple> results = em.createSelectionQuery(
							"select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id",
							Tuple.class
					)
					.getResultList();
			//end::vector-norm-example[]
			assertDistances( results, euclideanNorm( V1 ), euclideanNorm( V2 ), 0D );
		} );
	}

	/**
	 * Asserts that the result holds exactly the rows for ids 1 and 2, in that order, and that the
	 * second column of each row matches the expected value within {@code delta}.
	 */
	private static void assertDistances(List<Tuple> results, double expectedFirst, double expectedSecond, double delta) {
		assertEquals( 2, results.size() );
		assertEquals( 1L, results.get( 0 ).get( 0 ) );
		assertEquals( expectedFirst, results.get( 0 ).get( 1, double.class ), delta );
		assertEquals( 2L, results.get( 1 ).get( 0 ) );
		assertEquals( expectedSecond, results.get( 1 ).get( 1, double.class ), delta );
	}

	/** One minus the inner product divided by the product of the Euclidean norms. */
	private static double cosineDistance(float[] f1, float[] f2) {
		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
	}

	/** Square root of the summed squared component differences. */
	private static double euclideanDistance(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.pow( (double) f1[i] - f2[i], 2 );
		}
		return Math.sqrt( result );
	}

	/**
	 * Sum of the absolute component differences (L1 distance). The difference of the two L1 norms
	 * is not a distance: it can be negative, or zero for distinct vectors.
	 */
	private static double taxicabDistance(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.abs( (double) f1[i] - f2[i] );
		}
		return result;
	}

	/** Sum of the component-wise products. */
	private static double innerProduct(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += ( (double) f1[i] ) * ( (double) f2[i] );
		}
		return result;
	}

	/** Number of positions at which the two vectors differ. */
	public static double hammingDistance(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		int distance = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			if ( f1[i] != f2[i] ) {
				distance++;
			}
		}
		return distance;
	}

	/** Square root of the summed squared components. */
	private static double euclideanNorm(float[] f) {
		double result = 0;
		for ( double v : f ) {
			result += Math.pow( v, 2 );
		}
		return Math.sqrt( result );
	}

	/** Entity with a three-element {@code float[]} attribute mapped as {@code VECTOR_FLOAT32}. */
	@Entity(name = "VectorEntity")
	public static class VectorEntity {

		@Id
		private Long id;

		//tag::usage-example[]
		@Column(name = "the_vector")
		@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)
		@Array(length = 3)
		private float[] theVector;
		//end::usage-example[]

		public VectorEntity() {
		}

		public VectorEntity(Long id, float[] theVector) {
			this.id = id;
			this.theVector = theVector;
		}

		public Long getId() {
			return id;
		}

		public void setId(Long id) {
			this.id = id;
		}

		public float[] getTheVector() {
			return theVector;
		}

		public void setTheVector(float[] theVector) {
			this.theVector = theVector;
		}
	}
}

View File

@ -0,0 +1,302 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.vector;
import java.util.List;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.OracleDialect;
import org.hibernate.type.SqlTypes;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
* @author Hassan AL Meftah
*/
@DomainModel(annotatedClasses = OracleGenericVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
public class OracleGenericVectorTest {
// Fixture vectors; prepareData() persists V1 under id 1 and V2 under id 2
private static final float[] V1 = new float[] { 1, 2, 3 };
private static final float[] V2 = new float[] { 4, 5, 6 };
/**
 * Persists one entity per fixture vector before each test.
 */
@BeforeEach
public void prepareData(SessionFactoryScope scope) {
	scope.inTransaction( session -> {
		final VectorEntity first = new VectorEntity( 1L, V1 );
		final VectorEntity second = new VectorEntity( 2L, V2 );
		session.persist( first );
		session.persist( second );
	} );
}
/**
 * Deletes every {@code VectorEntity} row after each test via a bulk HQL delete.
 */
@AfterEach
public void cleanup(SessionFactoryScope scope) {
	scope.inTransaction( session -> {
		final String deleteAll = "delete from VectorEntity";
		session.createMutationQuery( deleteAll ).executeUpdate();
	} );
}
/**
 * Loads both fixture rows by id and checks that the stored vectors come back unchanged.
 */
@Test
public void testRead(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		final VectorEntity first = em.find( VectorEntity.class, 1L );
		assertArrayEquals( new float[] { 1, 2, 3 }, first.getTheVector() );

		final VectorEntity second = em.find( VectorEntity.class, 2L );
		assertArrayEquals( new float[] { 4, 5, 6 }, second.getTheVector() );
	} );
}
/**
 * Runs the HQL {@code cosine_distance} function against both fixture rows and compares each
 * result with the locally computed {@code cosineDistance} value, using a tolerance of 1e-7.
 */
@Test
public void testCosineDistance(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::cosine-distance-example[]
		final float[] vector = new float[] { 1, 1, 1 };
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id",
						Tuple.class
				)
				.setParameter( "vec", vector )
				.getResultList();
		//end::cosine-distance-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );
		final double tolerance = 0.0000001D;

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( cosineDistance( V1, vector ), first.get( 1, double.class ), tolerance );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( cosineDistance( V2, vector ), second.get( 1, double.class ), tolerance );
	} );
}
/**
 * Runs the HQL {@code euclidean_distance} function against both fixture rows and compares each
 * result with the locally computed {@code euclideanDistance} value, using a tolerance of 1e-6.
 */
@Test
public void testEuclideanDistance(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::euclidean-distance-example[]
		final float[] vector = new float[] { 1, 1, 1 };
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id",
						Tuple.class
				)
				.setParameter( "vec", vector )
				.getResultList();
		//end::euclidean-distance-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );
		final double tolerance = 0.000001D;

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( euclideanDistance( V1, vector ), first.get( 1, double.class ), tolerance );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( euclideanDistance( V2, vector ), second.get( 1, double.class ), tolerance );
	} );
}
/**
 * Runs the HQL {@code taxicab_distance} function against both fixture rows and expects each
 * result to equal the locally computed {@code taxicabDistance} value exactly (delta 0).
 */
@Test
public void testTaxicabDistance(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::taxicab-distance-example[]
		final float[] vector = new float[] { 1, 1, 1 };
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id",
						Tuple.class
				)
				.setParameter( "vec", vector )
				.getResultList();
		//end::taxicab-distance-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( taxicabDistance( V1, vector ), first.get( 1, double.class ), 0D );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( taxicabDistance( V2, vector ), second.get( 1, double.class ), 0D );
	} );
}
/**
 * Runs the HQL {@code inner_product} and {@code negative_inner_product} functions against both
 * fixture rows and expects them to equal the locally computed inner product and its negation
 * exactly (delta 0).
 */
@Test
public void testInnerProduct(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::inner-product-example[]
		final float[] vector = new float[] { 1, 1, 1 };
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id",
						Tuple.class
				)
				.setParameter( "vec", vector )
				.getResultList();
		//end::inner-product-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( innerProduct( V1, vector ), first.get( 1, double.class ), 0D );
		assertEquals( -innerProduct( V1, vector ), first.get( 2, double.class ), 0D );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( innerProduct( V2, vector ), second.get( 1, double.class ), 0D );
		assertEquals( -innerProduct( V2, vector ), second.get( 2, double.class ), 0D );
	} );
}
/**
 * Runs the HQL {@code hamming_distance} function against both fixture rows and expects each
 * result to equal the locally computed {@code hammingDistance} value exactly (delta 0).
 */
@Test
public void testHammingDistance(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		// The snippet tags below are specific to this test; reusing the inner-product tag name
		// would make a tag-filtered documentation include pick up this query as well.
		//tag::hamming-distance-example[]
		final float[] vector = new float[] { 1, 1, 1 };
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id",
						Tuple.class
				)
				.setParameter( "vec", vector )
				.getResultList();
		//end::hamming-distance-example[]
		assertEquals( 2, results.size() );
		assertEquals( 1L, results.get( 0 ).get( 0 ) );
		assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
		assertEquals( 2L, results.get( 1 ).get( 0 ) );
		assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
	} );
}
/**
 * Runs the HQL {@code vector_dims} function against both fixture rows and expects the length
 * of the corresponding fixture array for each row.
 */
@Test
public void testVectorDims(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::vector-dims-example[]
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id",
						Tuple.class
				)
				.getResultList();
		//end::vector-dims-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( V1.length, first.get( 1 ) );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( V2.length, second.get( 1 ) );
	} );
}
/**
 * Runs the HQL {@code vector_norm} function against both fixture rows and expects each result
 * to equal the locally computed {@code euclideanNorm} value exactly (delta 0).
 */
@Test
public void testVectorNorm(SessionFactoryScope scope) {
	scope.inTransaction( em -> {
		//tag::vector-norm-example[]
		final List<Tuple> results = em.createSelectionQuery(
						"select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id",
						Tuple.class
				)
				.getResultList();
		//end::vector-norm-example[]
		assertEquals( 2, results.size() );

		// Rows are ordered by id, so index 0 is the V1 row and index 1 is the V2 row.
		final Tuple first = results.get( 0 );
		final Tuple second = results.get( 1 );

		assertEquals( 1L, first.get( 0 ) );
		assertEquals( euclideanNorm( V1 ), first.get( 1, double.class ), 0D );
		assertEquals( 2L, second.get( 0 ) );
		assertEquals( euclideanNorm( V2 ), second.get( 1, double.class ), 0D );
	} );
}
/**
 * Reference cosine distance: one minus the inner product of the two vectors divided by the
 * product of their Euclidean norms.
 *
 * @param f1 first vector
 * @param f2 second vector
 * @return {@code 1 - (f1 . f2) / (|f1| * |f2|)}
 */
private static double cosineDistance(float[] f1, float[] f2) {
	final double dotProduct = innerProduct( f1, f2 );
	final double normProduct = euclideanNorm( f1 ) * euclideanNorm( f2 );
	return 1D - dotProduct / normProduct;
}
/**
 * Reference Euclidean (L2) distance: square root of the sum of squared component differences,
 * accumulated in double precision.
 *
 * @param f1 first vector
 * @param f2 second vector, asserted to have the same length as {@code f1}
 * @return the Euclidean distance between the two vectors
 */
private static double euclideanDistance(float[] f1, float[] f2) {
	assert f1.length == f2.length;
	double sumOfSquares = 0;
	for ( int i = 0; i < f1.length; i++ ) {
		final double delta = (double) f1[i] - f2[i];
		sumOfSquares += Math.pow( delta, 2 );
	}
	return Math.sqrt( sumOfSquares );
}
/**
 * Reference taxicab (Manhattan, L1) distance: the sum of the absolute component-wise
 * differences, accumulated in double precision.
 * <p>
 * The difference of the two L1 norms is not a substitute: it equals the distance only when
 * every component of one vector dominates the other, and can otherwise be too small or negative.
 *
 * @param f1 first vector
 * @param f2 second vector, asserted to have the same length as {@code f1}
 * @return the taxicab distance between the two vectors; never negative
 */
private static double taxicabDistance(float[] f1, float[] f2) {
	assert f1.length == f2.length;
	double result = 0;
	for ( int i = 0; i < f1.length; i++ ) {
		result += Math.abs( (double) f1[i] - f2[i] );
	}
	return result;
}
/**
 * Reference inner (dot) product: the sum of the pairwise component products, with each
 * component widened to double before multiplying.
 *
 * @param f1 first vector
 * @param f2 second vector, asserted to have the same length as {@code f1}
 * @return the inner product of the two vectors
 */
private static double innerProduct(float[] f1, float[] f2) {
	assert f1.length == f2.length;
	double dotProduct = 0;
	for ( int i = 0; i < f1.length; i++ ) {
		final double left = f1[i];
		final double right = f2[i];
		dotProduct += left * right;
	}
	return dotProduct;
}
/**
 * Reference Hamming distance: the number of positions at which the two vectors hold
 * different values, returned as a double.
 *
 * @param f1 first vector
 * @param f2 second vector, asserted to have the same length as {@code f1}
 * @return the count of differing positions
 */
public static double hammingDistance(float[] f1, float[] f2) {
	assert f1.length == f2.length;
	int mismatches = 0;
	for ( int i = 0; i < f1.length; i++ ) {
		mismatches += f1[i] != f2[i] ? 1 : 0;
	}
	return mismatches;
}
/**
 * Reference Euclidean (L2) norm: square root of the sum of the squared components,
 * accumulated in double precision.
 *
 * @param f the vector
 * @return the Euclidean norm of {@code f}
 */
private static double euclideanNorm(float[] f) {
	double sumOfSquares = 0;
	for ( int i = 0; i < f.length; i++ ) {
		final double component = f[i];
		sumOfSquares += Math.pow( component, 2 );
	}
	return Math.sqrt( sumOfSquares );
}
/**
 * L1 norm: the sum of the absolute values of the components, accumulated in double precision.
 *
 * @param f the vector
 * @return the sum of {@code |f[i]|} over all components
 */
private static double norm(float[] f) {
	double total = 0;
	for ( int i = 0; i < f.length; i++ ) {
		total += Math.abs( (double) f[i] );
	}
	return total;
}
/**
 * Test entity with a {@code Long} identifier and a three-element {@code float[]} stored in
 * column {@code the_vector} using the {@code SqlTypes.VECTOR} JDBC type code.
 */
@Entity(name = "VectorEntity")
public static class VectorEntity {

	@Id
	private Long id;

	//tag::usage-example[]
	@Column(name = "the_vector")
	@JdbcTypeCode(SqlTypes.VECTOR)
	@Array(length = 3)
	private float[] theVector;
	//end::usage-example[]

	/**
	 * No-argument constructor; leaves both fields unset.
	 */
	public VectorEntity() {
	}

	/**
	 * Creates an entity holding the given identifier and vector (the array reference is stored as-is).
	 */
	public VectorEntity(Long id, float[] theVector) {
		this.id = id;
		this.theVector = theVector;
	}

	/**
	 * @return the identifier
	 */
	public Long getId() {
		return this.id;
	}

	/**
	 * @param id the new identifier
	 */
	public void setId(Long id) {
		this.id = id;
	}

	/**
	 * @return the stored vector (the same array reference, not a copy)
	 */
	public float[] getTheVector() {
		return this.theVector;
	}

	/**
	 * @param theVector the new vector; stored by reference
	 */
	public void setTheVector(float[] theVector) {
		this.theVector = theVector;
	}
}
}