HHH-17738 : Add support for Oracle database AI Vector Search
This commit is contained in:
parent
4791b41cf5
commit
60b0a63629
|
@ -185,6 +185,11 @@ public class OracleDialect extends Dialect {
|
|||
// Is MAX_STRING_SIZE set to EXTENDED?
|
||||
protected final boolean extended;
|
||||
|
||||
protected final int driverMajorVersion;
|
||||
|
||||
protected final int driverMinorVersion;
|
||||
|
||||
|
||||
public OracleDialect() {
|
||||
this( MINIMUM_VERSION );
|
||||
}
|
||||
|
@ -193,6 +198,8 @@ public class OracleDialect extends Dialect {
|
|||
super(version);
|
||||
autonomous = false;
|
||||
extended = false;
|
||||
driverMajorVersion = 19;
|
||||
driverMinorVersion = 0;
|
||||
}
|
||||
|
||||
public OracleDialect(DialectResolutionInfo info) {
|
||||
|
@ -203,6 +210,8 @@ public class OracleDialect extends Dialect {
|
|||
super( info );
|
||||
autonomous = serverConfiguration.isAutonomous();
|
||||
extended = serverConfiguration.isExtended();
|
||||
this.driverMinorVersion = serverConfiguration.getDriverMinorVersion();
|
||||
this.driverMajorVersion = serverConfiguration.getDriverMajorVersion();
|
||||
}
|
||||
|
||||
@Deprecated( since = "6.4" )
|
||||
|
@ -1621,4 +1630,12 @@ public class OracleDialect extends Dialect {
|
|||
public boolean supportsFromClauseInUpdate() {
|
||||
return true;
|
||||
}
|
||||
|
||||
public int getDriverMajorVersion() {
|
||||
return driverMajorVersion;
|
||||
}
|
||||
|
||||
public int getDriverMinorVersion() {
|
||||
return driverMinorVersion;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,8 @@ import static org.hibernate.cfg.DialectSpecificSettings.ORACLE_EXTENDED_STRING_S
|
|||
public class OracleServerConfiguration {
|
||||
private final boolean autonomous;
|
||||
private final boolean extended;
|
||||
private final int driverMajorVersion;
|
||||
private final int driverMinorVersion;
|
||||
|
||||
public boolean isAutonomous() {
|
||||
return autonomous;
|
||||
|
@ -35,16 +37,39 @@ public class OracleServerConfiguration {
|
|||
return extended;
|
||||
}
|
||||
|
||||
public int getDriverMajorVersion() {
|
||||
return driverMajorVersion;
|
||||
}
|
||||
|
||||
public int getDriverMinorVersion() {
|
||||
return driverMinorVersion;
|
||||
}
|
||||
|
||||
public OracleServerConfiguration(boolean autonomous, boolean extended) {
|
||||
this( autonomous, extended, 19, 0 );
|
||||
}
|
||||
|
||||
public OracleServerConfiguration(
|
||||
boolean autonomous,
|
||||
boolean extended,
|
||||
int driverMajorVersion,
|
||||
int driverMinorVersion) {
|
||||
this.autonomous = autonomous;
|
||||
this.extended = extended;
|
||||
this.driverMajorVersion = driverMajorVersion;
|
||||
this.driverMinorVersion = driverMinorVersion;
|
||||
}
|
||||
|
||||
public static OracleServerConfiguration fromDialectResolutionInfo(DialectResolutionInfo info) {
|
||||
Boolean extended = null;
|
||||
Boolean autonomous = null;
|
||||
Integer majorVersion = null;
|
||||
Integer minorVersion = null;
|
||||
final DatabaseMetaData databaseMetaData = info.getDatabaseMetadata();
|
||||
if ( databaseMetaData != null ) {
|
||||
majorVersion = databaseMetaData.getDriverMajorVersion();
|
||||
minorVersion = databaseMetaData.getDriverMinorVersion();
|
||||
|
||||
try (final Statement statement = databaseMetaData.getConnection().createStatement()) {
|
||||
final ResultSet rs = statement.executeQuery(
|
||||
"select cast('string' as varchar2(32000)), " +
|
||||
|
@ -77,7 +102,19 @@ public class OracleServerConfiguration {
|
|||
false
|
||||
);
|
||||
}
|
||||
return new OracleServerConfiguration( autonomous, extended );
|
||||
if ( majorVersion == null ) {
|
||||
try {
|
||||
java.sql.Driver driver = java.sql.DriverManager.getDriver( "jdbc:oracle:thin:" );
|
||||
majorVersion = driver.getMajorVersion();
|
||||
minorVersion = driver.getMinorVersion();
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
majorVersion = 19;
|
||||
minorVersion = 0;
|
||||
}
|
||||
|
||||
}
|
||||
return new OracleServerConfiguration( autonomous, extended, majorVersion, minorVersion );
|
||||
}
|
||||
|
||||
private static boolean isAutonomous(String cloudServiceParam) {
|
||||
|
@ -86,11 +123,14 @@ public class OracleServerConfiguration {
|
|||
|
||||
private static boolean isAutonomous(DatabaseMetaData databaseMetaData) {
|
||||
try (final Statement statement = databaseMetaData.getConnection().createStatement()) {
|
||||
return statement.executeQuery( "select 1 from dual where sys_context('USERENV','CLOUD_SERVICE') in ('OLTP','DWCS','JDCS')" ).next();
|
||||
return statement.executeQuery(
|
||||
"select 1 from dual where sys_context('USERENV','CLOUD_SERVICE') in ('OLTP','DWCS','JDCS')" )
|
||||
.next();
|
||||
}
|
||||
catch (SQLException ex) {
|
||||
// Ignore
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -12,4 +12,9 @@ package org.hibernate.dialect;
|
|||
public class OracleTypes {
|
||||
public static final int CURSOR = -10;
|
||||
public static final int JSON = 2016;
|
||||
|
||||
public static final int VECTOR = -105;
|
||||
public static final int VECTOR_INT8 = -106;
|
||||
public static final int VECTOR_FLOAT32 = -107;
|
||||
public static final int VECTOR_FLOAT64 = -108;
|
||||
}
|
||||
|
|
|
@ -211,7 +211,6 @@ public class SqlTypes {
|
|||
* as a synonym for {@link #VARBINARY}.
|
||||
*
|
||||
* @see org.hibernate.Length#LONG
|
||||
*
|
||||
* @see Types#LONGVARBINARY
|
||||
* @see org.hibernate.type.descriptor.jdbc.LongVarbinaryJdbcType
|
||||
*/
|
||||
|
@ -356,7 +355,6 @@ public class SqlTypes {
|
|||
* as a synonym for {@link #NVARCHAR}.
|
||||
*
|
||||
* @see org.hibernate.Length#LONG
|
||||
*
|
||||
* @see Types#LONGNVARCHAR
|
||||
* @see org.hibernate.type.descriptor.jdbc.LongNVarcharJdbcType
|
||||
*/
|
||||
|
@ -657,13 +655,28 @@ public class SqlTypes {
|
|||
|
||||
/**
|
||||
* A type code representing an {@code embedding vector} type for databases like
|
||||
* {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} that have special extensions.
|
||||
* {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} and {@link org.hibernate.dialect.OracleDialect Oracle 23ai}.
|
||||
* An embedding vector essentially is a {@code float[]} with a fixed size.
|
||||
*
|
||||
* @since 6.4
|
||||
*/
|
||||
public static final int VECTOR = 10_000;
|
||||
|
||||
/**
|
||||
* A type code representing a single-byte integer vector type for oracle 23ai database.
|
||||
*/
|
||||
public static final int VECTOR_INT8 = 10_001;
|
||||
|
||||
/**
|
||||
* A type code representing a single-precision floating-point vector type for oracle 23ai database.
|
||||
*/
|
||||
public static final int VECTOR_FLOAT32 = 10_002;
|
||||
|
||||
/**
|
||||
* A type code representing a double-precision floating-point type for oracle 23ai database.
|
||||
*/
|
||||
public static final int VECTOR_FLOAT64 = 10_003;
|
||||
|
||||
private SqlTypes() {
|
||||
}
|
||||
|
||||
|
@ -693,6 +706,7 @@ public class SqlTypes {
|
|||
/**
|
||||
* Is this a type with a length, that is, is it
|
||||
* some kind of character string or binary string?
|
||||
*
|
||||
* @param typeCode a JDBC type code from {@link Types}
|
||||
*/
|
||||
public static boolean isStringType(int typeCode) {
|
||||
|
@ -715,6 +729,7 @@ public class SqlTypes {
|
|||
/**
|
||||
* Does the given JDBC type code represent some sort of
|
||||
* character string type?
|
||||
*
|
||||
* @param typeCode a JDBC type code from {@link Types}
|
||||
*/
|
||||
public static boolean isCharacterOrClobType(int typeCode) {
|
||||
|
@ -736,6 +751,7 @@ public class SqlTypes {
|
|||
/**
|
||||
* Does the given JDBC type code represent some sort of
|
||||
* character string type?
|
||||
*
|
||||
* @param typeCode a JDBC type code from {@link Types}
|
||||
*/
|
||||
public static boolean isCharacterType(int typeCode) {
|
||||
|
@ -755,6 +771,7 @@ public class SqlTypes {
|
|||
/**
|
||||
* Does the given JDBC type code represent some sort of
|
||||
* variable-length character string type?
|
||||
*
|
||||
* @param typeCode a JDBC type code from {@link Types}
|
||||
*/
|
||||
public static boolean isVarcharType(int typeCode) {
|
||||
|
|
|
@ -551,7 +551,6 @@ public final class StandardBasicTypes {
|
|||
);
|
||||
|
||||
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Binary mappings
|
||||
|
||||
|
@ -746,12 +745,36 @@ public final class StandardBasicTypes {
|
|||
|
||||
/**
|
||||
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
|
||||
* specifically for embedding vectors like provided by the PostgreSQL extension pgvector.
|
||||
* specifically for embedding vectors like provided by the PostgreSQL extension pgvector and Oracle 23ai.
|
||||
*/
|
||||
public static final BasicTypeReference<float[]> VECTOR = new BasicTypeReference<>(
|
||||
"vector", float[].class, SqlTypes.VECTOR
|
||||
);
|
||||
|
||||
/**
|
||||
* The standard Hibernate type for mapping {@code byte[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR_INT8 VECTOR_INT8},
|
||||
* specifically for embedding integer vectors (8-bits) like provided by Oracle 23ai.
|
||||
*/
|
||||
public static final BasicTypeReference<byte[]> VECTOR_INT8 = new BasicTypeReference<>(
|
||||
"byte_vector", byte[].class, SqlTypes.VECTOR_INT8
|
||||
);
|
||||
|
||||
/**
|
||||
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
|
||||
* specifically for embedding single-precision floating-point (32-bits) vectors like provided by Oracle 23ai.
|
||||
*/
|
||||
public static final BasicTypeReference<float[]> VECTOR_FLOAT32 = new BasicTypeReference<>(
|
||||
"float_vector", float[].class, SqlTypes.VECTOR_FLOAT32
|
||||
);
|
||||
|
||||
/**
|
||||
* The standard Hibernate type for mapping {@code double[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
|
||||
* specifically for embedding double-precision floating-point (64-bits) vectors like provided by Oracle 23ai.
|
||||
*/
|
||||
public static final BasicTypeReference<double[]> VECTOR_FLOAT64 = new BasicTypeReference<>(
|
||||
"double_vector", double[].class, SqlTypes.VECTOR_FLOAT64
|
||||
);
|
||||
|
||||
|
||||
public static void prime(TypeConfiguration typeConfiguration) {
|
||||
BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
|
@ -1252,6 +1275,27 @@ public final class StandardBasicTypes {
|
|||
"vector"
|
||||
);
|
||||
|
||||
handle(
|
||||
VECTOR_FLOAT32,
|
||||
null,
|
||||
basicTypeRegistry,
|
||||
"float_vector"
|
||||
);
|
||||
|
||||
handle(
|
||||
VECTOR_FLOAT64,
|
||||
null,
|
||||
basicTypeRegistry,
|
||||
"double_vector"
|
||||
);
|
||||
|
||||
handle(
|
||||
VECTOR_INT8,
|
||||
null,
|
||||
basicTypeRegistry,
|
||||
"byte_vector"
|
||||
);
|
||||
|
||||
|
||||
// Specialized version handlers
|
||||
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.sql.CallableStatement;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.ValueBinder;
|
||||
import org.hibernate.type.descriptor.ValueExtractor;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.BasicPluralJavaType;
|
||||
import org.hibernate.type.descriptor.java.ByteJavaType;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.java.PrimitiveByteArrayJavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.BasicBinder;
|
||||
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.internal.JdbcLiteralFormatterArray;
|
||||
|
||||
/**
|
||||
* Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle.
|
||||
* <p>
|
||||
* This class handles generic vectors represented by an asterisk (*) in the format,
|
||||
* allowing for different element types within the vector.
|
||||
*
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
public abstract class AbstractOracleVectorJdbcType extends ArrayJdbcType {
|
||||
|
||||
final boolean isVectorSupported;
|
||||
|
||||
public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
|
||||
super( elementJdbcType );
|
||||
this.isVectorSupported = isVectorSupported;
|
||||
}
|
||||
|
||||
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
|
||||
final JavaType<T> elementJavaType;
|
||||
if ( javaTypeDescriptor instanceof PrimitiveByteArrayJavaType ) {
|
||||
// Special handling needed for Byte[], because that would conflict with the VARBINARY mapping
|
||||
//noinspection unchecked
|
||||
elementJavaType = (JavaType<T>) ByteJavaType.INSTANCE;
|
||||
}
|
||||
else if ( javaTypeDescriptor instanceof BasicPluralJavaType ) {
|
||||
//noinspection unchecked
|
||||
elementJavaType = ( (BasicPluralJavaType<T>) javaTypeDescriptor ).getElementJavaType();
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException( "not a BasicPluralJavaType" );
|
||||
}
|
||||
return new JdbcLiteralFormatterArray<>(
|
||||
javaTypeDescriptor,
|
||||
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType )
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "OracleVectorTypeDescriptor";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
|
||||
return new BasicBinder<>( javaTypeDescriptor, this ) {
|
||||
@Override
|
||||
protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options)
|
||||
throws SQLException {
|
||||
if ( isVectorSupported ) {
|
||||
st.setObject( index, value, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getNativeTypeCode() );
|
||||
}
|
||||
else {
|
||||
st.setString( index, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getStringVector( value, getJavaType(), options ) );
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
|
||||
throws SQLException {
|
||||
if ( isVectorSupported ) {
|
||||
st.setObject( name, value, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getNativeTypeCode() );
|
||||
}
|
||||
else {
|
||||
st.setString( name, ( (AbstractOracleVectorJdbcType) getJdbcType() ).getStringVector( value, getJavaType(), options ) );
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public <X> ValueExtractor<X> getExtractor(final JavaType<X> javaTypeDescriptor) {
|
||||
return new BasicExtractor<>( javaTypeDescriptor, this ) {
|
||||
@Override
|
||||
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
|
||||
if ( isVectorSupported ) {
|
||||
return getJavaType().wrap( rs.getObject( paramIndex, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options );
|
||||
}
|
||||
else {
|
||||
return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( rs.getString( paramIndex ) ), options );
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
|
||||
if ( isVectorSupported ) {
|
||||
return getJavaType().wrap( statement.getObject( index, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options );
|
||||
}
|
||||
else {
|
||||
return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( statement.getString( index ) ), options );
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, String name, WrapperOptions options)
|
||||
throws SQLException {
|
||||
if ( isVectorSupported ) {
|
||||
return getJavaType().wrap( statement.getObject( name, ((AbstractOracleVectorJdbcType) getJdbcType() ).getNativeJavaType() ), options );
|
||||
}
|
||||
else {
|
||||
return getJavaType().wrap( ((AbstractOracleVectorJdbcType) getJdbcType() ).getVectorArray( statement.getString( name ) ), options );
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
protected abstract <T> T getVectorArray(String string);
|
||||
|
||||
protected abstract <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options);
|
||||
|
||||
protected abstract Class<?> getNativeJavaType();
|
||||
|
||||
protected abstract int getNativeTypeCode();
|
||||
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.BitSet;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleTypes;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
/**
|
||||
* Specialized type mapping for single-byte integer vector {@link SqlTypes#VECTOR_INT8} SQL data type for Oracle.
|
||||
*
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
public class OracleByteVectorJdbcType extends AbstractOracleVectorJdbcType {
|
||||
|
||||
|
||||
private static final byte[] EMPTY = new byte[0];
|
||||
|
||||
public OracleByteVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
|
||||
super( elementJdbcType, isVectorSupported );
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( "to_vector(" );
|
||||
appender.append( writeExpression );
|
||||
appender.append( ", *, INT8)" );
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFriendlyName() {
|
||||
return "VECTOR_INT8";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR_INT8;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected byte[] getVectorArray(String string) {
|
||||
if ( string == null ) {
|
||||
return null;
|
||||
}
|
||||
if ( string.length() == 2 ) {
|
||||
return EMPTY;
|
||||
}
|
||||
final BitSet commaPositions = new BitSet();
|
||||
int size = 1;
|
||||
for ( int i = 1; i < string.length(); i++ ) {
|
||||
final char c = string.charAt( i );
|
||||
if ( c == ',' ) {
|
||||
commaPositions.set( i );
|
||||
size++;
|
||||
}
|
||||
}
|
||||
final byte[] result = new byte[size];
|
||||
int doubleStartIndex = 1;
|
||||
int commaIndex;
|
||||
int index = 0;
|
||||
while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) {
|
||||
result[index++] = Byte.parseByte( string.substring( doubleStartIndex, commaIndex ) );
|
||||
doubleStartIndex = commaIndex + 1;
|
||||
}
|
||||
result[index] = Byte.parseByte( string.substring( doubleStartIndex, string.length() - 1 ) );
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
|
||||
return Arrays.toString( javaTypeDescriptor.unwrap( vector, byte[].class, options ) );
|
||||
}
|
||||
|
||||
protected Class<?> getNativeJavaType(){
|
||||
return byte[].class;
|
||||
};
|
||||
|
||||
protected int getNativeTypeCode(){
|
||||
return OracleTypes.VECTOR_INT8;
|
||||
};
|
||||
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.BitSet;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleTypes;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
/**
|
||||
* Specialized type mapping for double-precision floating-point vector {@link SqlTypes#VECTOR_FLOAT64} SQL data type for Oracle.
|
||||
*
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
public class OracleDoubleVectorJdbcType extends AbstractOracleVectorJdbcType {
|
||||
|
||||
private static final double[] EMPTY = new double[0];
|
||||
|
||||
public OracleDoubleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
|
||||
super( elementJdbcType, isVectorSupported );
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( "to_vector(" );
|
||||
appender.append( writeExpression );
|
||||
appender.append( ", *, FLOAT64)" );
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFriendlyName() {
|
||||
return "VECTOR_FLOAT64";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR_FLOAT64;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected double[] getVectorArray(String string) {
|
||||
if ( string == null ) {
|
||||
return null;
|
||||
}
|
||||
if ( string.length() == 2 ) {
|
||||
return EMPTY;
|
||||
}
|
||||
final BitSet commaPositions = new BitSet();
|
||||
int size = 1;
|
||||
for ( int i = 1; i < string.length(); i++ ) {
|
||||
final char c = string.charAt( i );
|
||||
if ( c == ',' ) {
|
||||
commaPositions.set( i );
|
||||
size++;
|
||||
}
|
||||
}
|
||||
final double[] result = new double[size];
|
||||
int doubleStartIndex = 1;
|
||||
int commaIndex;
|
||||
int index = 0;
|
||||
while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) {
|
||||
result[index++] = Double.parseDouble( string.substring( doubleStartIndex, commaIndex ) );
|
||||
doubleStartIndex = commaIndex + 1;
|
||||
}
|
||||
result[index] = Double.parseDouble( string.substring( doubleStartIndex, string.length() - 1 ) );
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
|
||||
return Arrays.toString( javaTypeDescriptor.unwrap( vector, double[].class, options ) );
|
||||
}
|
||||
|
||||
protected Class<?> getNativeJavaType() {
|
||||
return double[].class;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getNativeTypeCode() {
|
||||
return OracleTypes.VECTOR_FLOAT64;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.BitSet;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleTypes;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
/**
|
||||
* Specialized type mapping for single-precision floating-point vector {@link SqlTypes#VECTOR_FLOAT32} SQL data type for Oracle.
|
||||
*
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
|
||||
public class OracleFloatVectorJdbcType extends AbstractOracleVectorJdbcType {
|
||||
|
||||
|
||||
private static final float[] EMPTY = new float[0];
|
||||
|
||||
public OracleFloatVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
|
||||
super( elementJdbcType, isVectorSupported );
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( "to_vector(" );
|
||||
appender.append( writeExpression );
|
||||
appender.append( ", *, FLOAT32)" );
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFriendlyName() {
|
||||
return "VECTOR_FLOAT32";
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR_FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected float[] getVectorArray(String string) {
|
||||
if ( string == null ) {
|
||||
return null;
|
||||
}
|
||||
if ( string.length() == 2 ) {
|
||||
return EMPTY;
|
||||
}
|
||||
final BitSet commaPositions = new BitSet();
|
||||
int size = 1;
|
||||
for ( int i = 1; i < string.length(); i++ ) {
|
||||
final char c = string.charAt( i );
|
||||
if ( c == ',' ) {
|
||||
commaPositions.set( i );
|
||||
size++;
|
||||
}
|
||||
}
|
||||
final float[] result = new float[size];
|
||||
int doubleStartIndex = 1;
|
||||
int commaIndex;
|
||||
int index = 0;
|
||||
while ( ( commaIndex = commaPositions.nextSetBit( doubleStartIndex ) ) != -1 ) {
|
||||
result[index++] = Float.parseFloat( string.substring( doubleStartIndex, commaIndex ) );
|
||||
doubleStartIndex = commaIndex + 1;
|
||||
}
|
||||
result[index] = Float.parseFloat( string.substring( doubleStartIndex, string.length() - 1 ) );
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected <T> String getStringVector(T vector, JavaType<T> javaTypeDescriptor, WrapperOptions options) {
|
||||
return Arrays.toString( javaTypeDescriptor.unwrap( vector, float[].class, options ) );
|
||||
}
|
||||
|
||||
protected Class<?> getNativeJavaType() {
|
||||
return float[].class;
|
||||
}
|
||||
|
||||
protected int getNativeTypeCode() {
|
||||
return OracleTypes.VECTOR_FLOAT32;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import org.hibernate.boot.model.FunctionContributions;
|
||||
import org.hibernate.boot.model.FunctionContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class OracleVectorFunctionContributor implements FunctionContributor {
|
||||
|
||||
@Override
|
||||
public void contributeFunctions(FunctionContributions functionContributions) {
|
||||
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
|
||||
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
final Dialect dialect = functionContributions.getDialect();
|
||||
if ( dialect instanceof OracleDialect ) {
|
||||
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
|
||||
final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
|
||||
functionRegistry.patternDescriptorBuilder( "cosine_distance", "vector_distance(?1, ?2, COSINE)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "vector_distance(?1, ?2, EUCLIDEAN)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
|
||||
|
||||
functionRegistry.patternDescriptorBuilder( "l1_distance" , "vector_distance(?1, ?2, MANHATTAN)")
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
|
||||
|
||||
functionRegistry.patternDescriptorBuilder( "negative_inner_product", "vector_distance(?1, ?2, DOT)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "inner_product", "vector_distance(?1, ?2, DOT)*-1" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "hamming_distance", "vector_distance(?1, ?2, HAMMING)" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.namedDescriptorBuilder( "vector_dims" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 1 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( integerType ) )
|
||||
.register();
|
||||
functionRegistry.namedDescriptorBuilder( "vector_norm" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 1 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordinal() {
|
||||
return 200;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleTypes;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
/**
|
||||
* Specialized type mapping for generic vector {@link SqlTypes#VECTOR} SQL data type for Oracle.
|
||||
* <p>
|
||||
* This class handles generic vectors represented by an asterisk (*) in the format,
|
||||
* allowing for different element types within the vector.
|
||||
*
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
public class OracleVectorJdbcType extends OracleFloatVectorJdbcType {
|
||||
|
||||
|
||||
public OracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSupported) {
|
||||
super( elementJdbcType, isVectorSupported );
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFriendlyName() {
|
||||
return "VECTOR";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( "to_vector(" );
|
||||
appender.append( writeExpression );
|
||||
appender.append( ", *, *)" );
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int getNativeTypeCode() {
|
||||
return OracleTypes.VECTOR;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.sql.SQLException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import org.hibernate.boot.model.TypeContributions;
|
||||
import org.hibernate.boot.model.TypeContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.engine.jdbc.Size;
|
||||
import org.hibernate.engine.jdbc.spi.JdbcServices;
|
||||
import org.hibernate.internal.util.StringHelper;
|
||||
import org.hibernate.service.ServiceRegistry;
|
||||
import org.hibernate.type.BasicArrayType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeReference;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
|
||||
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class OracleVectorTypeContributor implements TypeContributor {
|
||||
|
||||
|
||||
@Override
|
||||
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
|
||||
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class )
|
||||
.getDialect();
|
||||
|
||||
if ( dialect instanceof OracleDialect && dialect.getVersion().isSameOrAfter( 23 ) ) {
|
||||
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
|
||||
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
|
||||
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
|
||||
final boolean isVectorSupported = isVectorSupportedByDriver( (OracleDialect) dialect );
|
||||
|
||||
// Register generic vector type
|
||||
final OracleVectorJdbcType genericVectorJdbcType = new OracleVectorJdbcType(
|
||||
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
|
||||
isVectorSupported
|
||||
);
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType );
|
||||
final JdbcType floatVectorJdbcType = new OracleFloatVectorJdbcType(
|
||||
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
|
||||
isVectorSupported
|
||||
);
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType );
|
||||
final JdbcType doubleVectorJdbcType = new OracleDoubleVectorJdbcType(
|
||||
jdbcTypeRegistry.getDescriptor( SqlTypes.DOUBLE ),
|
||||
isVectorSupported
|
||||
);
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT64, doubleVectorJdbcType );
|
||||
final JdbcType byteVectorJdbcType = new OracleByteVectorJdbcType(
|
||||
jdbcTypeRegistry.getDescriptor( SqlTypes.TINYINT ),
|
||||
isVectorSupported
|
||||
);
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_INT8, byteVectorJdbcType );
|
||||
|
||||
|
||||
// Resolving basic types after jdbc types are registered.
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
|
||||
genericVectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( float[].class )
|
||||
),
|
||||
StandardBasicTypes.VECTOR.getName()
|
||||
);
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
|
||||
floatVectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( float[].class )
|
||||
),
|
||||
StandardBasicTypes.VECTOR_FLOAT32.getName()
|
||||
);
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE ),
|
||||
doubleVectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( double[].class )
|
||||
),
|
||||
StandardBasicTypes.VECTOR_FLOAT64.getName()
|
||||
);
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
basicTypeRegistry.resolve( StandardBasicTypes.BYTE ),
|
||||
byteVectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( byte[].class )
|
||||
),
|
||||
StandardBasicTypes.VECTOR_INT8.getName()
|
||||
);
|
||||
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR, "vector($l, *)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return OracleVectorTypeContributor.replace(
|
||||
"vector($l, *)",
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR_INT8, "vector($l, INT8)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return OracleVectorTypeContributor.replace(
|
||||
"vector($l, INT8)",
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR_FLOAT32, "vector($l, FLOAT32)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return OracleVectorTypeContributor.replace(
|
||||
"vector($l, FLOAT32)",
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR_FLOAT64, "vector($l, FLOAT64)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return OracleVectorTypeContributor.replace(
|
||||
"vector($l, FLOAT64)",
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue()
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Replace vector dimension with the length or * for undefined length
|
||||
*/
|
||||
private static String replace(String type, Long size) {
|
||||
return StringHelper.replaceOnce( type, "$l", size != null ? size.toString() : "*" );
|
||||
}
|
||||
|
||||
private boolean isVectorSupportedByDriver(OracleDialect dialect) {
|
||||
int majorVersion = dialect.getDriverMajorVersion();
|
||||
int minorVersion = dialect.getDriverMinorVersion();
|
||||
|
||||
return ( majorVersion > 23 ) || ( majorVersion == 23 && minorVersion >= 4 );
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
|
|||
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
|
||||
import org.hibernate.query.sqm.tree.SqmTypedNode;
|
||||
import org.hibernate.type.BasicPluralType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
|
@ -24,6 +25,13 @@ public class VectorArgumentValidator implements ArgumentsValidator {
|
|||
|
||||
public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator();
|
||||
|
||||
private static final int[] availableVectorCodes = {
|
||||
SqlTypes.VECTOR,
|
||||
SqlTypes.VECTOR_INT8,
|
||||
SqlTypes.VECTOR_FLOAT32,
|
||||
SqlTypes.VECTOR_FLOAT64
|
||||
};
|
||||
|
||||
@Override
|
||||
public void validate(
|
||||
List<? extends SqmTypedNode<?>> arguments,
|
||||
|
@ -46,7 +54,18 @@ public class VectorArgumentValidator implements ArgumentsValidator {
|
|||
}
|
||||
|
||||
private static boolean isVectorType(SqmExpressible<?> vectorType) {
|
||||
return vectorType instanceof BasicPluralType<?, ?>
|
||||
&& ( (BasicPluralType<?, ?>) vectorType ).getJdbcType().getDefaultSqlTypeCode() == SqlTypes.VECTOR;
|
||||
if ( !( vectorType instanceof BasicPluralType<?, ?> ) ) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch ( ( (BasicType<?>) vectorType ).getJdbcType().getDefaultSqlTypeCode() ) {
|
||||
case SqlTypes.VECTOR:
|
||||
case SqlTypes.VECTOR_INT8:
|
||||
case SqlTypes.VECTOR_FLOAT32:
|
||||
case SqlTypes.VECTOR_FLOAT64:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
org.hibernate.vector.PGVectorFunctionContributor
|
||||
org.hibernate.vector.PGVectorFunctionContributor
|
||||
org.hibernate.vector.OracleVectorFunctionContributor
|
|
@ -1 +1,2 @@
|
|||
org.hibernate.vector.PGVectorTypeContributor
|
||||
org.hibernate.vector.PGVectorTypeContributor
|
||||
org.hibernate.vector.OracleVectorTypeContributor
|
|
@ -0,0 +1,280 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
@DomainModel(annotatedClasses = OracleByteVectorTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
|
||||
public class OracleByteVectorTest {
|
||||
|
||||
private static final byte[] V1 = new byte[]{ 1, 2, 3 };
|
||||
private static final byte[] V2 = new byte[]{ 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new byte[]{ 1, 2, 3 }, tableRecord.getTheVector() );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new byte[]{ 4, 5, 6 }, tableRecord.getTheVector() );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final byte[] vector = new byte[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector, byte[].class )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final byte[] vector = new byte[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTaxicabDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::taxicab-distance-example[]
|
||||
final byte[] vector = new byte[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::taxicab-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerProduct(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final byte[] vector = new byte[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHammingDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final byte[] vector = new byte[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDims(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-dims-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-dims-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( V1.length, results.get( 0 ).get( 1 ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( V2.length, results.get( 1 ).get( 1 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorNorm(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-norm-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-norm-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
private static double cosineDistance(byte[] f1, byte[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(byte[] f1, byte[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double taxicabDistance(byte[] f1, byte[] f2) {
|
||||
return norm( f1 ) - norm( f2 );
|
||||
}
|
||||
|
||||
private static double innerProduct(byte[] f1, byte[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public static double hammingDistance(byte[] f1, byte[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
int distance = 0;
|
||||
for (int i = 0; i < f1.length; i++) {
|
||||
if (!(f1[i] == f2[i])) {
|
||||
distance++;
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
|
||||
private static double euclideanNorm(byte[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double norm(byte[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.abs( v );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Entity( name = "VectorEntity" )
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column( name = "the_vector" )
|
||||
@JdbcTypeCode(SqlTypes.VECTOR_INT8)
|
||||
@Array(length = 3)
|
||||
private byte[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, byte[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public byte[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(byte[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,277 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
@DomainModel(annotatedClasses = OracleDoubleVectorTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
|
||||
public class OracleDoubleVectorTest {
|
||||
|
||||
private static final double[] V1 = new double[]{ 1, 2, 3 };
|
||||
private static final double[] V2 = new double[]{ 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new double[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new double[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final double[] vector = new double[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final double[] vector = new double[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.00002D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.00002D);
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTaxicabDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::taxicab-distance-example[]
|
||||
final double[] vector = new double[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::taxicab-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerProduct(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final double[] vector = new double[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHammingDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final double[] vector = new double[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDims(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-dims-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-dims-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( V1.length, results.get( 0 ).get( 1 ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( V2.length, results.get( 1 ).get( 1 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorNorm(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-norm-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-norm-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
|
||||
private static double cosineDistance(double[] f1, double[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(double[] f1, double[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double taxicabDistance(double[] f1, double[] f2) {
|
||||
return norm( f1 ) - norm( f2 );
|
||||
}
|
||||
|
||||
public static double hammingDistance(double[] f1, double[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
int distance = 0;
|
||||
for (int i = 0; i < f1.length; i++) {
|
||||
if (!(f1[i] == f2[i])) {
|
||||
distance++;
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
private static double innerProduct(double[] f1, double[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static double euclideanNorm(double[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double norm(double[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.abs( v );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Entity( name = "VectorEntity" )
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column( name = "the_vector" )
|
||||
@JdbcTypeCode(SqlTypes.VECTOR_FLOAT64)
|
||||
@Array(length = 3)
|
||||
private double[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, double[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public double[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(double[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,301 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
@DomainModel(annotatedClasses = OracleFloatVectorTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
|
||||
public class OracleFloatVectorTest {
|
||||
|
||||
private static final float[] V1 = new float[] { 1, 2, 3 };
|
||||
private static final float[] V2 = new float[] { 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector(), 0 );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector(), 0 );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTaxicabDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::taxicab-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::taxicab-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerProduct(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHammingDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDims(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-dims-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.getResultList();
|
||||
//end::vector-dims-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( V1.length, results.get( 0 ).get( 1 ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( V2.length, results.get( 1 ).get( 1 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorNorm(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-norm-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.getResultList();
|
||||
//end::vector-norm-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
|
||||
private static double cosineDistance(float[] f1, float[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double taxicabDistance(float[] f1, float[] f2) {
|
||||
return norm( f1 ) - norm( f2 );
|
||||
}
|
||||
|
||||
private static double innerProduct(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public static double hammingDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
int distance = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
if ( !( f1[i] == f2[i] ) ) {
|
||||
distance++;
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
|
||||
private static double euclideanNorm(float[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double norm(float[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.abs( v );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Entity(name = "VectorEntity")
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column(name = "the_vector")
|
||||
@JdbcTypeCode(SqlTypes.VECTOR_FLOAT32)
|
||||
@Array(length = 3)
|
||||
private float[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, float[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public float[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(float[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,302 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.OracleDialect;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
|
||||
/**
|
||||
* @author Hassan AL Meftah
|
||||
*/
|
||||
@DomainModel(annotatedClasses = OracleGenericVectorTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = OracleDialect.class, matchSubTypes = false, majorVersion = 23, minorVersion = 4)
|
||||
public class OracleGenericVectorTest {
|
||||
|
||||
private static final float[] V1 = new float[] { 1, 2, 3 };
|
||||
private static final float[] V2 = new float[] { 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new float[] { 1, 2, 3 }, tableRecord.getTheVector() );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new float[] { 4, 5, 6 }, tableRecord.getTheVector() );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.0000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.0000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0.000001D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0.000001D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTaxicabDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::taxicab-distance-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::taxicab-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerProduct(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHammingDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final float[] vector = new float[] { 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, hamming_distance(e.theVector, :vec) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V1, vector ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( hammingDistance( V2, vector ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDims(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-dims-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.getResultList();
|
||||
//end::vector-dims-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( V1.length, results.get( 0 ).get( 1 ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( V2.length, results.get( 1 ).get( 1 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorNorm(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-norm-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery(
|
||||
"select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id",
|
||||
Tuple.class
|
||||
)
|
||||
.getResultList();
|
||||
//end::vector-norm-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
|
||||
private static double cosineDistance(float[] f1, float[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double taxicabDistance(float[] f1, float[] f2) {
|
||||
return norm( f1 ) - norm( f2 );
|
||||
}
|
||||
|
||||
private static double innerProduct(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public static double hammingDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
int distance = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
if ( !( f1[i] == f2[i] ) ) {
|
||||
distance++;
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
}
|
||||
|
||||
|
||||
private static double euclideanNorm(float[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double norm(float[] f) {
|
||||
double result = 0;
|
||||
for ( double v : f ) {
|
||||
result += Math.abs( v );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Entity(name = "VectorEntity")
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column(name = "the_vector")
|
||||
@JdbcTypeCode(SqlTypes.VECTOR)
|
||||
@Array(length = 3)
|
||||
private float[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, float[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public float[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(float[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue