HHH-17357 Add hibernate-types module with pgvector support

This commit is contained in:
Christian Beikov 2023-10-31 19:47:13 +01:00
parent 029100651c
commit eebb305837
22 changed files with 1178 additions and 89 deletions

View File

@ -147,22 +147,26 @@ postgresql() {
# Starts a fresh PostgreSQL 12 container (PostGIS image 12-3.4) published on port 5432
# and enables the pgvector extension in the hibernate_orm_test database.
postgresql_12() {
    # Drop any previous container named "postgres"; ignore the error if there is none.
    $CONTAINER_CLI rm -f postgres || true
    $CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:12-3.4
    # Run the image's pgdg apt setup script (presumably registers the apt.postgresql.org
    # repository - confirm), install pgvector and create the extension.
    # `-y` on `apt install` is required: this exec has no TTY, so a confirmation prompt would abort.
    $CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install -y postgresql-12-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}
# Starts a fresh PostgreSQL 13 container (PostGIS image 13-3.1) published on port 5432
# and enables the pgvector extension in the hibernate_orm_test database.
postgresql_13() {
    # Drop any previous container named "postgres"; ignore the error if there is none.
    $CONTAINER_CLI rm -f postgres || true
    $CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:13-3.1
    # Run the image's pgdg apt setup script (presumably registers the apt.postgresql.org
    # repository - confirm), install pgvector and create the extension.
    # `-y` on `apt install` is required: this exec has no TTY, so a confirmation prompt would abort.
    $CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install -y postgresql-13-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}
# Starts a fresh PostgreSQL 14 container (PostGIS image 14-3.3) published on port 5432
# and enables the pgvector extension in the hibernate_orm_test database.
postgresql_14() {
    # Drop any previous container named "postgres"; ignore the error if there is none.
    $CONTAINER_CLI rm -f postgres || true
    $CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:14-3.3
    # Run the image's pgdg apt setup script (presumably registers the apt.postgresql.org
    # repository - confirm), install pgvector and create the extension.
    # `-y` on `apt install` is required: this exec has no TTY, so a confirmation prompt would abort.
    $CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install -y postgresql-14-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}
# Starts a fresh PostgreSQL 15 container (PostGIS image 15-3.3) published on port 5432,
# passing server settings that turn off fsync, synchronous commit and full page writes,
# and enables the pgvector extension in the hibernate_orm_test database.
postgresql_15() {
    # Drop any previous container named "postgres"; ignore the error if there is none.
    $CONTAINER_CLI rm -f postgres || true
    $CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 --tmpfs /pgtmpfs:size=131072k -d docker.io/postgis/postgis:15-3.3 \
      -c fsync=off -c synchronous_commit=off -c full_page_writes=off -c shared_buffers=256MB -c maintenance_work_mem=256MB -c max_wal_size=1GB -c checkpoint_timeout=1d
    # Run the image's pgdg apt setup script (presumably registers the apt.postgresql.org
    # repository - confirm), install pgvector and create the extension.
    # `-y` on `apt install` is required: this exec has no TTY, so a confirmation prompt would abort.
    $CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install -y postgresql-15-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
}
edb() {

View File

@ -34,6 +34,7 @@ include::chapters/query/hql/QueryLanguage.adoc[]
include::chapters/query/criteria/Criteria.adoc[]
include::chapters/query/native/Native.adoc[]
include::chapters/query/spatial/Spatial.adoc[]
include::chapters/query/types/TypesModule.adoc[]
include::chapters/multitenancy/MultiTenancy.adoc[]
include::chapters/envers/Envers.adoc[]
include::chapters/beans/Beans.adoc[]

View File

@ -0,0 +1,179 @@
[[types-module]]
== Hibernate Types module
:root-project-dir: ../../../../../../../..
:types-project-dir: {root-project-dir}/hibernate-types
:example-dir-types: {types-project-dir}/src/test/java/org/hibernate/types
:extrasdir: extras
[[types-module-overview]]
=== Overview
The Hibernate ORM core module tries to be as minimal as possible and only model functionality
that is somewhat "standard" in the SQL space or can only be modeled as part of the core module.
To avoid growing that module further unnecessarily, support for certain special SQL types or functions
is separated out into the Hibernate ORM types module.
[[types-module-setup]]
=== Setup
You need to include the `hibernate-types` dependency in your build environment.
For Maven, you need to add the following dependency:
[[types-module-setup-maven-example]]
.Maven dependency
====
[source,xml]
----
<dependency>
<groupId>org.hibernate.orm</groupId>
<artifactId>hibernate-types</artifactId>
<version>${hibernate.version}</version>
</dependency>
----
====
The module contains service implementations that are picked up by the Java `ServiceLoader` automatically,
so no further configuration is necessary to make the features available.
[[types-module-vector]]
=== Vector type support
The Hibernate ORM types module comes with support for a special `vector` data type that essentially represents an array of floats.
So far, only the PostgreSQL extension `pgvector` is supported, but in theory,
the vector specific functions could be implemented to work with every database that supports arrays.
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation].
[[types-module-vector-usage]]
==== Usage
Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`.
[[types-module-vector-usage-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=usage-example]
----
====
To cast the string representation of a vector to the vector data type, simply use an HQL cast, e.g. `cast('[1,2,3]' as vector)`.
[[types-module-vector-functions]]
==== Functions
Expressions of the vector type can be used with various vector functions.
[[types-module-vector-functions-overview]]
|===
| Function | Purpose
| `cosine_distance()` | Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors. Maps to the `<``=``>` operator
| `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator
| `l2_distance()` | Alias for `euclidean_distance()`
| `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors
| `l1_distance()` | Alias for `taxicab_distance()`
| `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors
| `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator
| `vector_dims()` | Determines the dimensions of a vector
| `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector
|===
In addition to these special vector functions, it is also possible to use vectors with the following built-in operators and aggregate functions:
`<vector1> + <vector2> = <vector3>`:: Element-wise addition of vectors.
`<vector1> - <vector2> = <vector3>`:: Element-wise subtraction of vectors.
`<vector1> * <vector2> = <vector3>`:: Element-wise multiplication of vectors.
`sum(<vector1>) = <vector2>`:: Aggregate function support for element-wise summation of vectors.
`avg(<vector1>) = <vector2>`:: Aggregate function support for element-wise average of vectors.
[[types-module-vector-functions-cosine-distance]]
===== `cosine_distance()`
Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors,
which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) )`. Maps to the `<``=``>` pgvector operator.
[[types-module-vector-functions-cosine-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=cosine-distance-example]
----
====
[[types-module-vector-functions-euclidean-distance]]
===== `euclidean_distance()` and `l2_distance()`
Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors,
which is `sqrt( sum( (v1_i - v2_i)^2 ) )`. Maps to the `<``-``>` pgvector operator.
The `l2_distance()` function is an alias.
[[types-module-vector-functions-euclidean-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=euclidean-distance-example]
----
====
[[types-module-vector-functions-taxicab-distance]]
===== `taxicab_distance()` and `l1_distance()`
Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors,
which is `sum( abs( v1_i - v2_i ) )`.
The `l1_distance()` function is an alias.
[[types-module-vector-functions-taxicab-distance-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=taxicab-distance-example]
----
====
[[types-module-vector-functions-inner-product]]
===== `inner_product()` and `negative_inner_product()`
Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors,
which is `sum( v1_i * v2_i )`. The `negative_inner_product()` function maps to the `<``#``>` pgvector operator,
and the `inner_product()` function maps to the same operator, but multiplies the result by `-1`.
[[types-module-vector-functions-inner-product-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=inner-product-example]
----
====
[[types-module-vector-functions-vector-dims]]
===== `vector_dims()`
Determines the dimensions of a vector.
[[types-module-vector-functions-vector-dims-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-dims-example]
----
====
[[types-module-vector-functions-vector-norm]]
===== `vector_norm()`
Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector,
which is `sqrt( sum( v_i^2 ) )`.
[[types-module-vector-functions-vector-norm-example]]
====
[source, JAVA, indent=0]
----
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-norm-example]
----
====

View File

@ -8,16 +8,23 @@ package org.hibernate.dialect.function;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.function.Supplier;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.sql.ast.Clause;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
@ -27,8 +34,14 @@ import org.hibernate.sql.ast.tree.expression.CastTarget;
import org.hibernate.sql.ast.tree.expression.Distinct;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.BasicType;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
import org.hibernate.type.spi.TypeConfiguration;
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
@ -49,10 +62,8 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
super(
"avg",
FunctionKind.AGGREGATE,
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
StandardFunctionReturnTypeResolvers.invariant(
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
),
new Validator(),
new ReturnTypeResolver( typeConfiguration ),
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, NUMERIC )
);
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
@ -131,4 +142,116 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
return "(NUMERIC arg)";
}
/**
 * Arguments validator for {@code avg()}.
 * <p>
 * Requires exactly one argument. When the type of that argument can be determined,
 * its JDBC type must either be numeric or be an array whose (possibly nested)
 * element JDBC type is numeric. When the argument has no expressible or no SQM type,
 * no type check is performed.
 */
public static class Validator implements ArgumentsValidator {

	public static final ArgumentsValidator INSTANCE = new Validator();

	/**
	 * Validates the arguments passed to the function.
	 *
	 * @param arguments the SQM arguments of the function invocation
	 * @param functionName the name the function was invoked with, used in error messages
	 * @param typeConfiguration used to determine a JDBC type for non-{@link JdbcMapping} domain types
	 *
	 * @throws FunctionArgumentException if the argument count is not 1,
	 * or the argument type is neither numeric nor an array of a numeric type
	 */
	@Override
	public void validate(
			List<? extends SqmTypedNode<?>> arguments,
			String functionName,
			TypeConfiguration typeConfiguration) {
		if ( arguments.size() != 1 ) {
			throw new FunctionArgumentException(
					String.format(
							Locale.ROOT,
							"Function %s() has %d parameters, but %d arguments given",
							functionName,
							1,
							arguments.size()
					)
			);
		}
		final SqmTypedNode<?> argument = arguments.get( 0 );
		final SqmExpressible<?> expressible = argument.getExpressible();
		final DomainType<?> domainType;
		if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
			final JdbcType jdbcType = getJdbcType( domainType, typeConfiguration );
			if ( !isNumeric( jdbcType ) ) {
				throw new FunctionArgumentException(
						// Locale.ROOT keeps the %d rendering locale-independent,
						// consistent with the argument count message above
						String.format(
								Locale.ROOT,
								"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
								1,
								functionName,
								NUMERIC,
								domainType.getTypeName()
						)
				);
			}
		}
	}

	/**
	 * Whether the given JDBC type has a numeric default SQL type code,
	 * or is an array type whose element type is (recursively) numeric.
	 */
	private static boolean isNumeric(JdbcType jdbcType) {
		final int sqlTypeCode = jdbcType.getDefaultSqlTypeCode();
		if ( SqlTypes.isNumericType( sqlTypeCode ) ) {
			return true;
		}
		if ( jdbcType instanceof ArrayJdbcType ) {
			// unwrap the array and check its element type
			return isNumeric( ( (ArrayJdbcType) jdbcType ).getElementJdbcType() );
		}
		return false;
	}

	/**
	 * Determines the JDBC type to validate for the given domain type: the mapped JDBC type
	 * if the domain type is a {@link JdbcMapping}, {@link ObjectJdbcType#INSTANCE} for enums,
	 * and otherwise the JDBC type recommended by the Java type.
	 */
	private static JdbcType getJdbcType(DomainType<?> domainType, TypeConfiguration typeConfiguration) {
		if ( domainType instanceof JdbcMapping ) {
			return ( (JdbcMapping) domainType ).getJdbcType();
		}
		else {
			final JavaType<?> javaType = domainType.getExpressibleJavaType();
			if ( javaType.getJavaTypeClass().isEnum() ) {
				// we can't tell if the enum is mapped STRING or ORDINAL
				return ObjectJdbcType.INSTANCE;
			}
			else {
				return javaType.getRecommendedJdbcType( typeConfiguration.getCurrentBaseSqlTypeIndicators() );
			}
		}
	}

	@Override
	public String getSignature() {
		return "(arg)";
	}
}
/**
 * Return type resolver for {@code avg()}.
 * <p>
 * If the first argument is typed as a {@link BasicPluralType}, that type is also the result type.
 * Otherwise the result is the {@code Double} basic type. The SQL AST variant
 * additionally gives precedence to a non-null implied type.
 */
public static class ReturnTypeResolver implements FunctionReturnTypeResolver {

	private final BasicType<Double> doubleType;

	public ReturnTypeResolver(TypeConfiguration typeConfiguration) {
		this.doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
	}

	@Override
	public BasicValuedMapping resolveFunctionReturnType(
			Supplier<BasicValuedMapping> impliedTypeAccess,
			List<? extends SqlAstNode> arguments) {
		final BasicValuedMapping implied = impliedTypeAccess.get();
		if ( implied == null ) {
			// no implied type, so derive the result type from the first argument
			final Expression firstArgument = (Expression) arguments.get( 0 );
			final JdbcMapping argumentMapping = firstArgument.getExpressionType().getSingleJdbcMapping();
			return argumentMapping instanceof BasicPluralType<?, ?>
					? (BasicValuedMapping) argumentMapping
					: doubleType;
		}
		return implied;
	}

	@Override
	public ReturnableType<?> resolveFunctionReturnType(
			ReturnableType<?> impliedType,
			Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
			List<? extends SqmTypedNode<?>> arguments,
			TypeConfiguration typeConfiguration) {
		final SqmExpressible<?> argumentExpressible = arguments.get( 0 ).getExpressible();
		final DomainType<?> argumentType = argumentExpressible == null
				? null
				: argumentExpressible.getSqmType();
		// instanceof is false for null, which covers the "type unknown" cases
		if ( argumentType instanceof BasicPluralType<?, ?> ) {
			return (ReturnableType<?>) argumentType;
		}
		return typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
	}
}
}

View File

@ -2009,11 +2009,11 @@ public class CommonFunctionFactory {
.setExactArgumentCount( 1 )
.register();
functionRegistry.namedAggregateDescriptorBuilder( "avg" )
.setArgumentRenderingMode( inferenceArgumentRenderingMode )
.setInvariantType(doubleType)
.setExactArgumentCount( 1 )
.setParameterTypes(NUMERIC)
.setArgumentsValidator( AvgFunction.Validator.INSTANCE )
.setReturnTypeResolver( new AvgFunction.ReturnTypeResolver( typeConfiguration ) )
.register();
functionRegistry.register(

View File

@ -7,6 +7,7 @@
package org.hibernate.dialect.function;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.query.ReturnableType;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
@ -17,12 +18,12 @@ import org.hibernate.type.spi.TypeConfiguration;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.sql.Types;
import java.util.List;
import java.util.function.Supplier;
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentType;
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentValuedMapping;
import static org.hibernate.type.SqlTypes.*;
/**
* Resolve according to JPA spec 4.8.5
@ -84,18 +85,20 @@ class SumReturnTypeResolver implements FunctionReturnTypeResolver {
}
}
switch ( basicType.getJdbcType().getDefaultSqlTypeCode() ) {
case Types.SMALLINT:
case Types.TINYINT:
case Types.INTEGER:
case Types.BIGINT:
case SMALLINT:
case TINYINT:
case INTEGER:
case BIGINT:
return longType;
case Types.FLOAT:
case Types.REAL:
case Types.DOUBLE:
case FLOAT:
case REAL:
case DOUBLE:
return doubleType;
case Types.DECIMAL:
case Types.NUMERIC:
case DECIMAL:
case NUMERIC:
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
case VECTOR:
return basicType;
}
return bigDecimalType;
}
@ -112,22 +115,23 @@ class SumReturnTypeResolver implements FunctionReturnTypeResolver {
}
// Resolve according to JPA spec 4.8.5
final BasicValuedMapping specifiedArgType = extractArgumentValuedMapping( arguments, 1 );
switch ( specifiedArgType.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
case Types.SMALLINT:
case Types.TINYINT:
case Types.INTEGER:
case Types.BIGINT:
final JdbcMapping jdbcMapping = specifiedArgType.getJdbcMapping();
switch ( jdbcMapping.getJdbcType().getDefaultSqlTypeCode() ) {
case SMALLINT:
case TINYINT:
case INTEGER:
case BIGINT:
return longType;
case Types.FLOAT:
case Types.REAL:
case Types.DOUBLE:
case FLOAT:
case REAL:
case DOUBLE:
return doubleType;
case Types.DECIMAL:
case Types.NUMERIC:
final Class<?> argTypeClass = specifiedArgType.getJdbcMapping()
.getJavaTypeDescriptor()
.getJavaTypeClass();
return BigInteger.class.isAssignableFrom(argTypeClass) ? bigIntegerType : bigDecimalType;
case DECIMAL:
case NUMERIC:
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
case VECTOR:
return (BasicValuedMapping) jdbcMapping;
}
return bigDecimalType;
}

View File

@ -25,6 +25,7 @@ import org.hibernate.query.sqm.tree.domain.SqmPath;
import org.hibernate.query.sqm.tree.domain.SqmPluralValuedSimplePath;
import org.hibernate.query.sqm.tree.expression.SqmExpression;
import org.hibernate.query.sqm.tree.expression.SqmLiteralNull;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.BasicType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
@ -99,7 +100,8 @@ public class TypecheckUtil {
* @see #isTypeAssignable(SqmPathSource, SqmExpressible, SessionFactoryImplementor)
*/
public static boolean areTypesComparable(
SqmExpressible<?> lhsType, SqmExpressible<?> rhsType,
SqmExpressible<?> lhsType,
SqmExpressible<?> rhsType,
SessionFactoryImplementor factory) {
if ( lhsType == null || rhsType == null || lhsType == rhsType ) {
@ -118,7 +120,10 @@ public class TypecheckUtil {
// for embeddables, the embeddable class must match exactly
if ( lhsType instanceof EmbeddedSqmPathSource && rhsType instanceof EmbeddedSqmPathSource ) {
return areEmbeddableTypesComparable( (EmbeddedSqmPathSource<?>) lhsType, (EmbeddedSqmPathSource<?>) rhsType );
return areEmbeddableTypesComparable(
(EmbeddedSqmPathSource<?>) lhsType,
(EmbeddedSqmPathSource<?>) rhsType
);
}
// for tuple constructors, we must check each element
@ -186,12 +191,17 @@ public class TypecheckUtil {
return false;
}
private static boolean areEmbeddableTypesComparable(EmbeddedSqmPathSource<?> lhsType, EmbeddedSqmPathSource<?> rhsType) {
private static boolean areEmbeddableTypesComparable(
EmbeddedSqmPathSource<?> lhsType,
EmbeddedSqmPathSource<?> rhsType) {
// no polymorphism for embeddable types
return rhsType.getNodeJavaType() == lhsType.getNodeJavaType();
}
private static boolean areTupleTypesComparable(SessionFactoryImplementor factory, TupleType<?> lhsTuple, TupleType<?> rhsTuple) {
private static boolean areTupleTypesComparable(
SessionFactoryImplementor factory,
TupleType<?> lhsTuple,
TupleType<?> rhsTuple) {
if ( rhsTuple.componentCount() != lhsTuple.componentCount() ) {
return false;
}
@ -387,16 +397,18 @@ public class TypecheckUtil {
if ( !Number.class.isAssignableFrom( rightJavaType )
// we can scale a duration by a number
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a numeric type (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')"
);
}
break;
default:
if ( !Number.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
);
}
break;
}
@ -408,15 +420,17 @@ public class TypecheckUtil {
case SUBTRACT:
// we can add/subtract durations
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
);
}
break;
default:
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + leftNodeType.getTypeName() +
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
);
}
}
else if ( Temporal.class.isAssignableFrom( leftJavaType )
@ -426,9 +440,10 @@ public class TypecheckUtil {
case ADD:
// we can add a duration to date, time, or datetime
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
);
}
break;
case SUBTRACT:
@ -437,32 +452,57 @@ public class TypecheckUtil {
&& !java.util.Date.class.isAssignableFrom( rightJavaType )
// we can subtract a duration from a date, time, or datetime
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
);
}
break;
default:
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + leftNodeType.getTypeName() +
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
);
}
}
else if ( isNumberArray( leftNodeType ) ) {
// left operand is a number
if ( !isNumberArray( rightNodeType ) ) {
throw new SemanticException(
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
"' which is not a numeric array type" + " (it is not an instance of 'java.lang.Number[]')"
);
}
}
else {
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of " + op.getOperatorSqlText()
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')"
);
}
}
}
/**
 * Whether the given expressible has an SQM type that is a {@link BasicPluralType}
 * whose element Java type is a subtype of {@link Number}.
 * Returns {@code false} when the expressible or its SQM type is {@code null}.
 */
private static boolean isNumberArray(SqmExpressible<?> expressible) {
	if ( expressible == null ) {
		return false;
	}
	final DomainType<?> sqmType = expressible.getSqmType();
	if ( !( sqmType instanceof BasicPluralType<?, ?> ) ) {
		// also covers a null SQM type
		return false;
	}
	final Class<?> elementClass = ( (BasicPluralType<?, ?>) sqmType ).getElementType().getJavaType();
	return Number.class.isAssignableFrom( elementClass );
}
public static void assertString(SqmExpression<?> expression) {
final SqmExpressible<?> nodeType = expression.getNodeType();
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( javaType != String.class && javaType != char[].class ) {
throw new SemanticException( "Operand of 'like' is of type '" + nodeType.getTypeName() + "' which is not a string"
+ " (it is not an instance of 'java.lang.String' or 'char[]')" );
throw new SemanticException(
"Operand of 'like' is of type '" + nodeType.getTypeName() +
"' which is not a string (it is not an instance of 'java.lang.String' or 'char[]')"
);
}
}
}
@ -487,8 +527,10 @@ public class TypecheckUtil {
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration"
+ " (it is not an instance of 'java.time.TemporalAmount')" );
throw new SemanticException(
"Operand of 'by' is of type '" + nodeType.getTypeName() +
"' which is not a duration (it is not an instance of 'java.time.TemporalAmount')"
);
}
}
}
@ -498,9 +540,10 @@ public class TypecheckUtil {
if ( nodeType != null ) {
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
if ( !Number.class.isAssignableFrom( javaType ) ) {
throw new SemanticException( "Operand of " + op.getOperatorChar()
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
+ " (it is not an instance of 'java.lang.Number')" );
throw new SemanticException(
"Operand of " + op.getOperatorChar() + " is of type '" + nodeType.getTypeName() +
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
);
}
}
}

View File

@ -218,14 +218,14 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
&& isUnknown( ((BasicType<?>) expressionType).getJavaTypeDescriptor() );
}
private int validateArgument(int count, JdbcMappingContainer expressionType, String functionName) {
private int validateArgument(int paramNumber, JdbcMappingContainer expressionType, String functionName) {
final int jdbcTypeCount = expressionType.getJdbcTypeCount();
for ( int i = 0; i < jdbcTypeCount; i++ ) {
final JdbcMapping mapping = expressionType.getJdbcMapping( i );
FunctionParameterType type = count < types.length ? types[count++] : types[types.length - 1];
if (type != null) {
FunctionParameterType type = paramNumber < types.length ? types[paramNumber++] : types[types.length - 1];
if ( type != null ) {
checkArgumentType(
count,
paramNumber,
functionName,
type,
mapping.getJdbcType().getDefaultSqlTypeCode(),
@ -233,10 +233,10 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
);
}
}
return count;
return paramNumber;
}
private void checkArgumentType(int count, String functionName, FunctionParameterType type, int code, Type javaType) {
private static void checkArgumentType(int paramNumber, String functionName, FunctionParameterType type, int code, Type javaType) {
switch (type) {
case COMPARABLE:
if ( !isCharacterType(code) && !isTemporalType(code) && !isNumericType(code) && !isEnumType( code )
@ -246,63 +246,63 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
// as a special case, we consider a binary column
// comparable when it is mapped by a Java UUID
&& !( javaType == java.util.UUID.class && code == Types.BINARY ) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case STRING:
if ( !isCharacterType(code) && !isEnumType(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case STRING_OR_CLOB:
if ( !isCharacterOrClobType(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case NUMERIC:
if ( !isNumericType(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case INTEGER:
if ( !isIntegral(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case BOOLEAN:
// ugh, need to be careful here, need to accept all the
// JDBC type codes that a Dialect might use for BOOLEAN
if ( code != BOOLEAN && code != BIT && code != TINYINT && code != SMALLINT ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case TEMPORAL:
if ( !isTemporalType(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case DATE:
if ( !hasDatePart(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case TIME:
if ( !hasTimePart(code) ) {
throwError(type, javaType, functionName, count);
throwError(type, javaType, functionName, paramNumber);
}
break;
case SPATIAL:
if ( !isSpatialType( code ) ) {
throwError( type, javaType, functionName, count );
throwError( type, javaType, functionName, paramNumber );
}
}
}
private void throwError(FunctionParameterType type, Type javaType, String functionName, int count) {
private static void throwError(FunctionParameterType type, Type javaType, String functionName, int paramNumber) {
throw new FunctionArgumentException(
String.format(
"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
count,
paramNumber,
functionName,
type,
javaType.getTypeName()

View File

@ -388,6 +388,8 @@ import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.converter.spi.BasicValueConverter;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.java.JavaTypeHelper;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators;
import org.hibernate.type.internal.BasicTypeImpl;
import org.hibernate.type.spi.TypeConfiguration;
@ -5861,7 +5863,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
// Only use the inferred mapping as parameter type when the JavaType accepts values of the bind type
if ( inferredJdbcMapping.getMappedJavaType().isWider( paramJdbcMapping.getMappedJavaType() )
// and the bind type is not explicit or the bind type has the same JDBC type
&& ( !bindingTypeExplicit || paramJdbcMapping.getJdbcType() == inferredJdbcMapping.getJdbcType() ) ) {
&& ( !bindingTypeExplicit || canUseInferredType( paramJdbcMapping, inferredJdbcMapping ) ) ) {
return resolveInferredValueMappingForParameter( inferredValueMapping );
}
}
@ -5952,6 +5954,15 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
throw new ConversionException( "Could not determine ValueMapping for SqmParameter: " + sqmParameter );
}
/**
 * Decides whether the inferred JDBC mapping may replace an explicitly bound one.
 * <p>
 * The inferred type is usable when both mappings share the very same {@link JdbcType} instance,
 * or when both are {@link ArrayJdbcType}s whose element JDBC types are the same instance.
 * In every other case the explicit bind type wins.
 *
 * @param bindJdbcMapping the mapping of the explicitly bound parameter type
 * @param inferredJdbcMapping the mapping inferred from the query context
 * @return {@code true} if the inferred mapping can be used in place of the bind mapping
 */
private static boolean canUseInferredType(JdbcMapping bindJdbcMapping, JdbcMapping inferredJdbcMapping) {
	final JdbcType bound = bindJdbcMapping.getJdbcType();
	final JdbcType inferred = inferredJdbcMapping.getJdbcType();
	if ( bound == inferred ) {
		return true;
	}
	// A differing JDBC type is only tolerated for arrays, and only if the element types agree.
	if ( !( bound instanceof ArrayJdbcType ) || !( inferred instanceof ArrayJdbcType ) ) {
		return false;
	}
	final JdbcType boundElement = ( (ArrayJdbcType) bound ).getElementJdbcType();
	final JdbcType inferredElement = ( (ArrayJdbcType) inferred ).getElementJdbcType();
	return boundElement == inferredElement;
}
private MappingModelExpressible<?> resolveInferredValueMappingForParameter(MappingModelExpressible<?> inferredValueMapping) {
if ( inferredValueMapping instanceof PluralAttributeMapping ) {
// For parameters, we resolve to the element descriptor
@ -6336,6 +6347,9 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
if ( nodeType instanceof BasicValuedMapping ) {
return (BasicValuedMapping) nodeType;
}
else if ( nodeType.getSqmType() instanceof BasicValuedMapping ) {
return (BasicValuedMapping) nodeType.getSqmType();
}
else {
return getTypeConfiguration().getBasicTypeForJavaType(
nodeType.getExpressibleJavaType().getJavaTypeClass()

View File

@ -212,7 +212,7 @@ public class BasicTypeRegistry implements Serializable {
}
/**
 * Registers the given type under a single registration key.
 *
 * @param type the basic type to register
 * @param key the registration key to associate with the type
 */
public void register(BasicType<?> type, String key) {
// Store the type directly in the name-keyed map.
typesByName.put( key, type );
// Also run the varargs overload with the key wrapped in a one-element array.
// NOTE(review): both statements are visible in this view; confirm that both are intended to remain.
register( type, new String[]{ key } );
}
public void register(BasicType<?> type, String... keys) {

View File

@ -551,6 +551,16 @@ public class SqlTypes {
*/
public static final int NAMED_ENUM = 6001;
/**
 * A type code representing an {@code embedding vector} type for databases such as
 * {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} that support it through special extensions.
 * An embedding vector is essentially a {@code float[]} with a fixed size.
 *
 * @since 6.4
 */
public static final int VECTOR = 10_000;
// Private, empty constructor: code outside this class cannot create instances.
private SqlTypes() {
}

View File

@ -744,6 +744,15 @@ public final class StandardBasicTypes {
);
/**
 * The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
 * specifically for embedding vectors such as those provided by the PostgreSQL extension pgvector.
 * The reference is declared under the name {@code "vector"}.
 */
public static final BasicTypeReference<float[]> VECTOR = new BasicTypeReference<>(
"vector", float[].class, SqlTypes.VECTOR
);
public static void prime(TypeConfiguration typeConfiguration) {
BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
@ -1236,6 +1245,13 @@ public final class StandardBasicTypes {
"url", java.net.URL.class.getName()
);
handle(
VECTOR,
null,
basicTypeRegistry,
"vector"
);
// Specialized version handlers

View File

@ -0,0 +1,21 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
// Human-readable description of this module.
description = 'Hibernate\'s type extensions'
// Apply the build logic defined in the root project's gradle/published-java-module.gradle script.
apply from: rootProject.file( 'gradle/published-java-module.gradle' )
dependencies {
// hibernate-core is declared as an 'api' dependency of this module.
api project( ':hibernate-core' )
// Test-only dependencies: the hibernate-testing module and the 'tests' configuration of hibernate-core.
testImplementation project( ':hibernate-testing' )
testImplementation project( path: ':hibernate-core', configuration: 'tests' )
}
test {
// Include everything matching the '**/**' pattern in the test task.
include '**/**'
}

View File

@ -0,0 +1,98 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.boot.model.FunctionContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * A {@link FunctionContributor} that registers vector functions when the dialect is a
 * {@link PostgreSQLDialect}. For any other dialect nothing is registered.
 * <p>
 * Every function validates its argument count and that its arguments are vector typed
 * (via {@link VectorArgumentValidator}), and resolves argument types through
 * {@link VectorArgumentTypeResolver}.
 */
public class PGVectorFunctionContributor implements FunctionContributor {

	@Override
	public void contributeFunctions(FunctionContributions functionContributions) {
		final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
		final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
		final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
		final Dialect dialect = functionContributions.getDialect();
		if ( !( dialect instanceof PostgreSQLDialect ) ) {
			return;
		}

		final BasicType<Double> doubleResult = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
		final BasicType<Integer> integerResult = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );

		// Binary functions rendered through an SQL operator pattern.
		registerBinaryPattern( functionRegistry, "cosine_distance", "?1<=>?2", doubleResult );
		registerBinaryPattern( functionRegistry, "euclidean_distance", "?1<->?2", doubleResult );
		functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );

		registerNamed( functionRegistry, "l1_distance", 2, doubleResult );
		functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );

		registerBinaryPattern( functionRegistry, "negative_inner_product", "?1<#>?2", doubleResult );
		// The inner product is rendered as the negated result of the <#> operator.
		registerBinaryPattern( functionRegistry, "inner_product", "(?1<#>?2)*-1", doubleResult );

		// Unary functions rendered as plain named function calls.
		registerNamed( functionRegistry, "vector_dims", 1, integerResult );
		registerNamed( functionRegistry, "vector_norm", 1, doubleResult );
	}

	/**
	 * Registers a two-argument vector function that is rendered with the given SQL pattern.
	 */
	private static void registerBinaryPattern(
			SqmFunctionRegistry functionRegistry,
			String name,
			String pattern,
			BasicType<?> returnType) {
		functionRegistry.patternDescriptorBuilder( name, pattern )
				.setArgumentsValidator( StandardArgumentsValidators.composite(
						StandardArgumentsValidators.exactly( 2 ),
						VectorArgumentValidator.INSTANCE
				) )
				.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
				.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
				.register();
	}

	/**
	 * Registers a vector function that is rendered as a named function call taking exactly
	 * {@code argumentCount} arguments.
	 */
	private static void registerNamed(
			SqmFunctionRegistry functionRegistry,
			String name,
			int argumentCount,
			BasicType<?> returnType) {
		functionRegistry.namedDescriptorBuilder( name )
				.setArgumentsValidator( StandardArgumentsValidators.composite(
						StandardArgumentsValidators.exactly( argumentCount ),
						VectorArgumentValidator.INSTANCE
				) )
				.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
				.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( returnType ) )
				.register();
	}

	@Override
	public int ordinal() {
		return 200;
	}
}

View File

@ -0,0 +1,71 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import java.lang.reflect.Type;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.boot.model.TypeContributor;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.engine.jdbc.Size;
import org.hibernate.engine.jdbc.spi.JdbcServices;
import org.hibernate.service.ServiceRegistry;
import org.hibernate.type.BasicArrayType;
import org.hibernate.type.BasicType;
import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * A {@link TypeContributor} that, when the dialect is a {@link PostgreSQLDialect}, registers
 * the JDBC type, the basic types and the DDL type for {@link SqlTypes#VECTOR}.
 * For any other dialect nothing is contributed.
 */
public class PGVectorTypeContributor implements TypeContributor {

	/** The Java array types for which a vector basic type is registered. */
	private static final Type[] VECTOR_JAVA_TYPES = {
			Float[].class,
			float[].class
	};

	@Override
	public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
		final Dialect dialect = serviceRegistry.getService( JdbcServices.class ).getDialect();
		if ( !( dialect instanceof PostgreSQLDialect ) ) {
			return;
		}

		final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
		final JavaTypeRegistry javaTypes = typeConfiguration.getJavaTypeRegistry();
		final JdbcTypeRegistry jdbcTypes = typeConfiguration.getJdbcTypeRegistry();
		final BasicTypeRegistry basicTypes = typeConfiguration.getBasicTypeRegistry();

		// The vector JDBC type is an array JDBC type over the registered FLOAT descriptor.
		final BasicType<Float> elementType = basicTypes.resolve( StandardBasicTypes.FLOAT );
		final ArrayJdbcType vectorJdbcType = new VectorJdbcType( jdbcTypes.getDescriptor( SqlTypes.FLOAT ) );
		jdbcTypes.addDescriptor( SqlTypes.VECTOR, vectorJdbcType );

		// Each supported Java array type is registered under the same key.
		final String registrationKey = StandardBasicTypes.VECTOR.getName();
		for ( Type javaArrayType : VECTOR_JAVA_TYPES ) {
			basicTypes.register(
					new BasicArrayType<>(
							elementType,
							vectorJdbcType,
							javaTypes.getDescriptor( javaArrayType )
					),
					registrationKey
			);
		}

		typeConfiguration.getDdlTypeRegistry().addDescriptor(
				new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) {
					@Override
					public String getTypeName(Size size) {
						// The array length, when present, is passed on as the length of the DDL type.
						final Integer arrayLength = size.getArrayLength();
						final Long length = arrayLength == null ? null : arrayLength.longValue();
						return getTypeName( length, null, null );
					}
				}
		);
	}
}

View File

@ -0,0 +1,51 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import java.util.List;
import org.hibernate.metamodel.mapping.MappingModelExpressible;
import org.hibernate.query.sqm.produce.function.FunctionArgumentTypeResolver;
import org.hibernate.query.sqm.sql.SqmToSqlAstConverter;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.query.sqm.tree.expression.SqmExpression;
import org.hibernate.query.sqm.tree.expression.SqmFunction;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.StandardBasicTypes;
/**
 * A {@link FunctionArgumentTypeResolver} for {@link SqlTypes#VECTOR} functions.
 * <p>
 * The type of an argument is taken from the first other argument of the same function whose
 * value mapping can be determined; if there is none, the {@link StandardBasicTypes#VECTOR}
 * basic type is used.
 */
public class VectorArgumentTypeResolver implements FunctionArgumentTypeResolver {

	public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver();

	@Override
	public MappingModelExpressible<?> resolveFunctionArgumentType(
			SqmFunction<?> function,
			int argumentIndex,
			SqmToSqlAstConverter converter) {
		final List<? extends SqmTypedNode<?>> allArguments = function.getArguments();
		int position = -1;
		for ( SqmTypedNode<?> candidate : allArguments ) {
			position++;
			// Skip the argument being resolved and anything that is not an expression.
			if ( position == argumentIndex || !( candidate instanceof SqmExpression<?> ) ) {
				continue;
			}
			final MappingModelExpressible<?> mapping =
					converter.determineValueMapping( (SqmExpression<?>) candidate );
			if ( mapping != null ) {
				return mapping;
			}
		}
		// No sibling argument provided a mapping: fall back to the registered vector basic type.
		return converter.getCreationContext()
				.getSessionFactory()
				.getTypeConfiguration()
				.getBasicTypeRegistry()
				.resolve( StandardBasicTypes.VECTOR );
	}
}

View File

@ -0,0 +1,52 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import java.util.List;
import org.hibernate.metamodel.model.domain.DomainType;
import org.hibernate.query.sqm.SqmExpressible;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
import org.hibernate.query.sqm.tree.SqmTypedNode;
import org.hibernate.type.BasicPluralType;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * An {@link ArgumentsValidator} that validates the arguments are all vector types i.e. {@link org.hibernate.type.SqlTypes#VECTOR}.
 * <p>
 * Arguments whose expressible or SQM type is not known are accepted without further checks.
 */
public class VectorArgumentValidator implements ArgumentsValidator {

	public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator();

	/**
	 * Checks every argument with a known type and rejects the first one that is not a vector type.
	 *
	 * @param arguments the function arguments to validate
	 * @param functionName the name of the function, used in the error message
	 * @param typeConfiguration not used by this validator
	 * @throws FunctionArgumentException if an argument has a known type that is not a vector type
	 */
	@Override
	public void validate(
			List<? extends SqmTypedNode<?>> arguments,
			String functionName,
			TypeConfiguration typeConfiguration) {
		for ( int i = 0; i < arguments.size(); i++ ) {
			final SqmExpressible<?> expressible = arguments.get( i ).getExpressible();
			final DomainType<?> type;
			if ( expressible != null && ( type = expressible.getSqmType() ) != null && !isVectorType( type ) ) {
				throw new FunctionArgumentException(
						String.format(
								"Parameter %d of function '%s()' requires a vector type, but argument is of type '%s'",
								// Report the human-readable, 1-based position rather than the loop index
								i + 1,
								functionName,
								type.getTypeName()
						)
				);
			}
		}
	}

	/**
	 * Whether the given type is a plural basic type whose JDBC type reports {@link SqlTypes#VECTOR}
	 * as its default SQL type code.
	 */
	private static boolean isVectorType(SqmExpressible<?> vectorType) {
		return vectorType instanceof BasicPluralType<?, ?>
				&& ( (BasicPluralType<?, ?>) vectorType ).getJdbcType().getDefaultSqlTypeCode() == SqlTypes.VECTOR;
	}
}

View File

@ -0,0 +1,96 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import java.sql.CallableStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.BitSet;
import org.hibernate.dialect.Dialect;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.type.SqlTypes;
import org.hibernate.type.descriptor.ValueExtractor;
import org.hibernate.type.descriptor.WrapperOptions;
import org.hibernate.type.descriptor.java.JavaType;
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
import org.hibernate.type.descriptor.jdbc.JdbcType;
import org.hibernate.type.spi.TypeConfiguration;
/**
 * An {@link ArrayJdbcType} reporting {@link SqlTypes#VECTOR} as its SQL type code.
 * <p>
 * Values are written through a {@code cast(... as vector)} expression and are read back as
 * strings, which are then parsed into a {@code float[]}.
 */
public class VectorJdbcType extends ArrayJdbcType {

	private static final float[] EMPTY = new float[0];

	public VectorJdbcType(JdbcType elementJdbcType) {
		super( elementJdbcType );
	}

	@Override
	public int getDefaultSqlTypeCode() {
		return SqlTypes.VECTOR;
	}

	@Override
	public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
			Integer precision,
			Integer scale,
			TypeConfiguration typeConfiguration) {
		return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
	}

	@Override
	public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
		appender.append( "cast(" );
		appender.append( writeExpression );
		appender.append( " as vector)" );
	}

	@Override
	public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
		return new BasicExtractor<>( javaTypeDescriptor, this ) {
			@Override
			protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
				return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options );
			}

			@Override
			protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
				return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options );
			}

			@Override
			protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
				return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options );
			}
		};
	}

	/**
	 * Parses the string form of a vector into a {@code float[]}.
	 * <p>
	 * The first and last characters are treated as delimiters and skipped; the text in between
	 * is split at commas and each piece is parsed with {@link Float#parseFloat(String)}.
	 *
	 * @param string the string read from the driver, possibly {@code null} for a SQL {@code NULL}
	 * @return the parsed values, an empty array if the string consists of the two delimiters only,
	 * or {@code null} if the string is {@code null}
	 */
	private static float[] parseFloatVector(String string) {
		if ( string == null ) {
			// getString() yields null for SQL NULL; propagate it instead of failing with a NullPointerException
			return null;
		}
		if ( string.length() == 2 ) {
			return EMPTY;
		}
		// First pass: remember where the commas are and count the elements
		final BitSet commaPositions = new BitSet();
		int size = 1;
		for ( int i = 1; i < string.length(); i++ ) {
			final char c = string.charAt( i );
			if ( c == ',' ) {
				commaPositions.set( i );
				size++;
			}
		}
		// Second pass: parse the text between consecutive commas
		final float[] result = new float[size];
		int floatStartIndex = 1;
		int commaIndex;
		int index = 0;
		while ( ( commaIndex = commaPositions.nextSetBit( floatStartIndex ) ) != -1 ) {
			result[index++] = Float.parseFloat( string.substring( floatStartIndex, commaIndex ) );
			floatStartIndex = commaIndex + 1;
		}
		// The last element ends right before the closing delimiter
		result[index] = Float.parseFloat( string.substring( floatStartIndex, string.length() - 1 ) );
		return result;
	}
}

View File

@ -0,0 +1 @@
org.hibernate.types.vector.PGVectorFunctionContributor

View File

@ -0,0 +1 @@
org.hibernate.types.vector.PGVectorTypeContributor

View File

@ -0,0 +1,303 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.types.vector;
import java.util.List;
import org.hibernate.annotations.Array;
import org.hibernate.annotations.JdbcTypeCode;
import org.hibernate.dialect.PostgreSQLDialect;
import org.hibernate.type.SqlTypes;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Tuple;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests reading, the vector functions and the vector arithmetic for a {@code float[]} attribute
 * mapped with {@link SqlTypes#VECTOR}. Expected values are computed in Java by the private
 * helper methods at the bottom of this class.
 *
 * @author Christian Beikov
 */
@DomainModel(annotatedClasses = PGVectorTest.VectorEntity.class)
@SessionFactory
@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false)
public class PGVectorTest {

	private static final float[] V1 = new float[]{ 1, 2, 3 };
	private static final float[] V2 = new float[]{ 4, 5, 6 };

	@BeforeEach
	public void prepareData(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.persist( new VectorEntity( 1L, V1 ) );
			em.persist( new VectorEntity( 2L, V2 ) );
		} );
	}

	@AfterEach
	public void cleanup(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
		} );
	}

	@Test
	public void testRead(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			VectorEntity tableRecord;
			tableRecord = em.find( VectorEntity.class, 1L );
			assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );

			tableRecord = em.find( VectorEntity.class, 2L );
			assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
		} );
	}

	@Test
	public void testCosineDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::cosine-distance-example[]
			final float[] vector = new float[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::cosine-distance-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D );
		} );
	}

	@Test
	public void testEuclideanDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::euclidean-distance-example[]
			final float[] vector = new float[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::euclidean-distance-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
		} );
	}

	@Test
	public void testTaxicabDistance(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::taxicab-distance-example[]
			final float[] vector = new float[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::taxicab-distance-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
		} );
	}

	@Test
	public void testInnerProduct(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::inner-product-example[]
			final float[] vector = new float[]{ 1, 1, 1 };
			final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
					.setParameter( "vec", vector )
					.getResultList();
			//end::inner-product-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
			assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, Double.class ), 0D );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
			assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, Double.class ), 0D );
		} );
	}

	@Test
	public void testVectorDims(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-dims-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-dims-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( V1.length, results.get( 0 ).get( 1 ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( V2.length, results.get( 1 ).get( 1 ) );
		} );
	}

	@Test
	public void testVectorNorm(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-norm-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-norm-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, Double.class ), 0D );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, Double.class ), 0D );
		} );
	}

	@Test
	public void testVectorSum(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-sum-example[]
			final List<float[]> results = em.createSelectionQuery( "select sum(e.theVector) from VectorEntity e", float[].class )
					.getResultList();
			//end::vector-sum-example[]
			assertEquals( 1, results.size() );
			assertArrayEquals( new float[]{ 5, 7, 9 }, results.get( 0 ) );
		} );
	}

	@Test
	public void testVectorAvg(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-avg-example[]
			final List<float[]> results = em.createSelectionQuery( "select avg(e.theVector) from VectorEntity e", float[].class )
					.getResultList();
			//end::vector-avg-example[]
			assertEquals( 1, results.size() );
			assertArrayEquals( new float[]{ 2.5f, 3.5f, 4.5f }, results.get( 0 ) );
		} );
	}

	@Test
	public void testAddition(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-addition-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, e.theVector + cast('[1, 1, 1]' as vector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-addition-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertArrayEquals( new float[]{ 2, 3, 4 }, results.get( 0 ).get( 1, float[].class ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertArrayEquals( new float[]{ 5, 6, 7 }, results.get( 1 ).get( 1, float[].class ) );
		} );
	}

	@Test
	public void testMultiplication(SessionFactoryScope scope) {
		scope.inTransaction( em -> {
			//tag::vector-multiplication-example[]
			final List<Tuple> results = em.createSelectionQuery( "select e.id, e.theVector * cast('[2, 2, 2]' as vector) from VectorEntity e order by e.id", Tuple.class )
					.getResultList();
			//end::vector-multiplication-example[]
			assertEquals( 2, results.size() );
			assertEquals( 1L, results.get( 0 ).get( 0 ) );
			assertArrayEquals( new float[]{ 2, 4, 6 }, results.get( 0 ).get( 1, float[].class ) );
			assertEquals( 2L, results.get( 1 ).get( 0 ) );
			assertArrayEquals( new float[]{ 8, 10, 12 }, results.get( 1 ).get( 1, float[].class ) );
		} );
	}

	/** Cosine distance: one minus the inner product divided by the product of the Euclidean norms. */
	private static double cosineDistance(float[] f1, float[] f2) {
		return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
	}

	/** Euclidean (L2) distance: square root of the sum of squared component differences. */
	private static double euclideanDistance(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.pow( (double) f1[i] - f2[i], 2 );
		}
		return Math.sqrt( result );
	}

	/**
	 * Taxicab (L1) distance: sum of the absolute component differences.
	 * Note that this is not the same as the difference of the two L1 norms, which only
	 * coincides with the distance when one vector dominates the other in every component.
	 */
	private static double taxicabDistance(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += Math.abs( (double) f1[i] - f2[i] );
		}
		return result;
	}

	/** Inner (dot) product: sum of the component-wise products. */
	private static double innerProduct(float[] f1, float[] f2) {
		assert f1.length == f2.length;
		double result = 0;
		for ( int i = 0; i < f1.length; i++ ) {
			result += ( (double) f1[i] ) * ( (double) f2[i] );
		}
		return result;
	}

	/** Euclidean (L2) norm: square root of the sum of squared components. */
	private static double euclideanNorm(float[] f) {
		double result = 0;
		for ( float v : f ) {
			result += Math.pow( v, 2 );
		}
		return Math.sqrt( result );
	}

	@Entity( name = "VectorEntity" )
	public static class VectorEntity {

		@Id
		private Long id;

		//tag::usage-example[]
		@Column( name = "the_vector" )
		@JdbcTypeCode(SqlTypes.VECTOR)
		@Array(length = 3)
		private float[] theVector;
		//end::usage-example[]

		public VectorEntity() {
		}

		public VectorEntity(Long id, float[] theVector) {
			this.id = id;
			this.theVector = theVector;
		}

		public Long getId() {
			return id;
		}

		public void setId(Long id) {
			this.id = id;
		}

		public float[] getTheVector() {
			return theVector;
		}

		public void setTheVector(float[] theVector) {
			this.theVector = theVector;
		}
	}
}

View File

@ -300,6 +300,7 @@ include 'hibernate-spatial'
include 'hibernate-platform'
include 'hibernate-community-dialects'
include 'hibernate-types'
include 'hibernate-c3p0'
include 'hibernate-proxool'