HHH-17357 Add hibernate-types module with pgvector support
This commit is contained in:
parent
029100651c
commit
eebb305837
|
@ -147,22 +147,26 @@ postgresql() {
|
|||
postgresql_12() {
|
||||
$CONTAINER_CLI rm -f postgres || true
|
||||
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:12-3.4
|
||||
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-12-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
|
||||
}
|
||||
|
||||
postgresql_13() {
|
||||
$CONTAINER_CLI rm -f postgres || true
|
||||
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:13-3.1
|
||||
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-13-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
|
||||
}
|
||||
|
||||
postgresql_14() {
|
||||
$CONTAINER_CLI rm -f postgres || true
|
||||
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 -d docker.io/postgis/postgis:14-3.3
|
||||
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-14-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
|
||||
}
|
||||
|
||||
postgresql_15() {
|
||||
$CONTAINER_CLI rm -f postgres || true
|
||||
$CONTAINER_CLI run --name postgres -e POSTGRES_USER=hibernate_orm_test -e POSTGRES_PASSWORD=hibernate_orm_test -e POSTGRES_DB=hibernate_orm_test -p5432:5432 --tmpfs /pgtmpfs:size=131072k -d docker.io/postgis/postgis:15-3.3 \
|
||||
-c fsync=off -c synchronous_commit=off -c full_page_writes=off -c shared_buffers=256MB -c maintenance_work_mem=256MB -c max_wal_size=1GB -c checkpoint_timeout=1d
|
||||
$CONTAINER_CLI exec postgres bash -c '/usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -y && apt install postgresql-15-pgvector && psql -U hibernate_orm_test -d hibernate_orm_test -c "create extension vector;"'
|
||||
}
|
||||
|
||||
edb() {
|
||||
|
|
|
@ -34,6 +34,7 @@ include::chapters/query/hql/QueryLanguage.adoc[]
|
|||
include::chapters/query/criteria/Criteria.adoc[]
|
||||
include::chapters/query/native/Native.adoc[]
|
||||
include::chapters/query/spatial/Spatial.adoc[]
|
||||
include::chapters/query/types/TypesModule.adoc[]
|
||||
include::chapters/multitenancy/MultiTenancy.adoc[]
|
||||
include::chapters/envers/Envers.adoc[]
|
||||
include::chapters/beans/Beans.adoc[]
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
[[types-module]]
|
||||
== Hibernate Types module
|
||||
:root-project-dir: ../../../../../../../..
|
||||
:types-project-dir: {root-project-dir}/hibernate-types
|
||||
:example-dir-types: {types-project-dir}/src/test/java/org/hibernate/types
|
||||
:extrasdir: extras
|
||||
|
||||
[[types-module-overview]]
|
||||
=== Overview
|
||||
|
||||
The Hibernate ORM core module tries to be as minimal as possible and only model functionality
|
||||
that is somewhat "standard" in the SQL space or can only be modeled as part of the core module.
|
||||
To avoid growing that module further unnecessarily, support for certain special SQL types or functions
|
||||
is separated out into the Hibernate ORM types module.
|
||||
|
||||
[[types-module-setup]]
|
||||
=== Setup
|
||||
|
||||
You need to include the `hibernate-types` dependency in your build environment.
|
||||
For Maven, you need to add the following dependency:
|
||||
|
||||
[[types-module-setup-maven-example]]
|
||||
.Maven dependency
|
||||
====
|
||||
[source,xml]
|
||||
----
|
||||
<dependency>
|
||||
<groupId>org.hibernate.orm</groupId>
|
||||
<artifactId>hibernate-types</artifactId>
|
||||
<version>${hibernate.version}</version>
|
||||
</dependency>
|
||||
----
|
||||
====
|
||||
|
||||
The module contains service implementations that are picked up by the Java `ServiceLoader` automatically,
|
||||
so no further configuration is necessary to make the features available.
|
||||
|
||||
[[types-module-vector]]
|
||||
=== Vector type support
|
||||
|
||||
The Hibernate ORM types module comes with support for a special `vector` data type that essentially represents an array of floats.
|
||||
|
||||
So far, only the PostgreSQL extension `pgvector` is supported, but in theory,
|
||||
the vector specific functions could be implemented to work with every database that supports arrays.
|
||||
|
||||
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation].
|
||||
|
||||
[[types-module-vector-usage]]
|
||||
==== Usage
|
||||
|
||||
Annotate a persistent attribute with `@JdbcTypeCode(SqlTypes.VECTOR)` and specify the vector length with `@Array(length = ...)`.
|
||||
|
||||
[[types-module-vector-usage-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=usage-example]
|
||||
----
|
||||
====
|
||||
|
||||
To cast the string representation of a vector to the vector data type, simply use an HQL cast i.e. `cast('[1,2,3]' as vector)`.
|
||||
|
||||
[[types-module-vector-functions]]
|
||||
==== Functions
|
||||
|
||||
Expressions of the vector type can be used with various vector functions.
|
||||
|
||||
[[types-module-vector-functions-overview]]
|
||||
|===
|
||||
| Function | Purpose
|
||||
|
||||
| `cosine_distance()` | Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors. Maps to the `<``=``>` operator
|
||||
| `euclidean_distance()` | Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors. Maps to the `<``-``>` operator
|
||||
| `l2_distance()` | Alias for `euclidean_distance()`
|
||||
| `taxicab_distance()` | Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors
|
||||
| `l1_distance()` | Alias for `taxicab_distance()`
|
||||
| `inner_product()` | Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors
|
||||
| `negative_inner_product()` | Computes the negative inner product. Maps to the `<``#``>` operator
|
||||
| `vector_dims()` | Determines the dimensions of a vector
|
||||
| `vector_norm()` | Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector
|
||||
|===
|
||||
|
||||
In addition to these special vector functions, it is also possible to use vectors with the following builtin operators
|
||||
|
||||
`<vector1> + <vector2> = <vector3>`:: Element-wise addition of vectors.
|
||||
`<vector1> - <vector2> = <vector3>`:: Element-wise subtraction of vectors.
|
||||
`<vector1> * <vector2> = <vector3>`:: Element-wise multiplication of vectors.
|
||||
`sum(<vector1>) = <vector2>`:: Aggregate function support for element-wise summation of vectors.
|
||||
`avg(<vector1>) = <vector2>`:: Aggregate function support for element-wise average of vectors.
|
||||
|
||||
[[types-module-vector-functions-cosine-distance]]
|
||||
===== `cosine_distance()`
|
||||
|
||||
Computes the https://en.wikipedia.org/wiki/Cosine_similarity[cosine distance] between two vectors,
|
||||
which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) )`. Maps to the `<``=``>` pgvector operator.
|
||||
|
||||
[[types-module-vector-functions-cosine-distance-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=cosine-distance-example]
|
||||
----
|
||||
====
|
||||
|
||||
[[types-module-vector-functions-euclidean-distance]]
|
||||
===== `euclidean_distance()` and `l2_distance()`
|
||||
|
||||
Computes the https://en.wikipedia.org/wiki/Euclidean_distance[euclidean distance] between two vectors,
|
||||
which is `sqrt( sum( (v1_i - v2_i)^2 ) )`. Maps to the `<``-``>` pgvector operator.
|
||||
The `l2_distance()` function is an alias.
|
||||
|
||||
[[types-module-vector-functions-euclidean-distance-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=euclidean-distance-example]
|
||||
----
|
||||
====
|
||||
|
||||
[[types-module-vector-functions-taxicab-distance]]
|
||||
===== `taxicab_distance()` and `l1_distance()`
|
||||
|
||||
Computes the https://en.wikipedia.org/wiki/Taxicab_geometry[taxicab distance] between two vectors,
|
||||
which is `vector_norm(v1) - vector_norm(v2)`.
|
||||
The `l1_distance()` function is an alias.
|
||||
|
||||
[[types-module-vector-functions-taxicab-distance-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=taxicab-distance-example]
|
||||
----
|
||||
====
|
||||
|
||||
[[types-module-vector-functions-inner-product]]
|
||||
===== `inner_product()` and `negative_inner_product()`
|
||||
|
||||
Computes the https://en.wikipedia.org/wiki/Inner_product_space[inner product] between two vectors,
|
||||
which is `sum( v1_i * v2_i )`. The `negative_inner_product()` function maps to the `<``#``>` pgvector operator,
|
||||
and the `inner_product()` function as well, but multiplies the result time `-1`.
|
||||
|
||||
[[types-module-vector-functions-inner-product-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=inner-product-example]
|
||||
----
|
||||
====
|
||||
|
||||
[[types-module-vector-functions-vector-dims]]
|
||||
===== `vector_dims()`
|
||||
|
||||
Determines the dimensions of a vector.
|
||||
|
||||
[[types-module-vector-functions-vector-dims-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-dims-example]
|
||||
----
|
||||
====
|
||||
|
||||
[[types-module-vector-functions-vector-norm]]
|
||||
===== `vector_norm()`
|
||||
|
||||
Computes the https://en.wikipedia.org/wiki/Euclidean_space#Euclidean_norm[Euclidean norm] of a vector,
|
||||
which is `sqrt( sum( v_i^2 ) )`.
|
||||
|
||||
[[types-module-vector-functions-vector-norm-example]]
|
||||
====
|
||||
[source, JAVA, indent=0]
|
||||
----
|
||||
include::{example-dir-types}/vector/PGVectorTest.java[tags=vector-norm-example]
|
||||
----
|
||||
====
|
||||
|
||||
|
||||
|
||||
|
|
@ -8,16 +8,23 @@ package org.hibernate.dialect.function;
|
|||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.metamodel.mapping.BasicValuedMapping;
|
||||
import org.hibernate.metamodel.mapping.JdbcMapping;
|
||||
import org.hibernate.metamodel.mapping.MappingModelExpressible;
|
||||
import org.hibernate.metamodel.model.domain.DomainType;
|
||||
import org.hibernate.query.ReturnableType;
|
||||
import org.hibernate.query.sqm.SqmExpressible;
|
||||
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
|
||||
import org.hibernate.query.sqm.function.FunctionKind;
|
||||
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
|
||||
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
|
||||
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.query.sqm.tree.SqmTypedNode;
|
||||
import org.hibernate.sql.ast.Clause;
|
||||
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
|
||||
import org.hibernate.sql.ast.SqlAstTranslator;
|
||||
|
@ -27,8 +34,14 @@ import org.hibernate.sql.ast.tree.expression.CastTarget;
|
|||
import org.hibernate.sql.ast.tree.expression.Distinct;
|
||||
import org.hibernate.sql.ast.tree.expression.Expression;
|
||||
import org.hibernate.sql.ast.tree.predicate.Predicate;
|
||||
import org.hibernate.type.BasicPluralType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.ObjectJdbcType;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.NUMERIC;
|
||||
|
@ -49,10 +62,8 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
|||
super(
|
||||
"avg",
|
||||
FunctionKind.AGGREGATE,
|
||||
new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 1 ), NUMERIC ),
|
||||
StandardFunctionReturnTypeResolvers.invariant(
|
||||
typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE )
|
||||
),
|
||||
new Validator(),
|
||||
new ReturnTypeResolver( typeConfiguration ),
|
||||
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, NUMERIC )
|
||||
);
|
||||
this.defaultArgumentRenderingMode = defaultArgumentRenderingMode;
|
||||
|
@ -131,4 +142,116 @@ public class AvgFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
|
|||
return "(NUMERIC arg)";
|
||||
}
|
||||
|
||||
public static class Validator implements ArgumentsValidator {
|
||||
|
||||
public static final ArgumentsValidator INSTANCE = new Validator();
|
||||
|
||||
@Override
|
||||
public void validate(
|
||||
List<? extends SqmTypedNode<?>> arguments,
|
||||
String functionName,
|
||||
TypeConfiguration typeConfiguration) {
|
||||
if ( arguments.size() != 1 ) {
|
||||
throw new FunctionArgumentException(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"Function %s() has %d parameters, but %d arguments given",
|
||||
functionName,
|
||||
1,
|
||||
arguments.size()
|
||||
)
|
||||
);
|
||||
}
|
||||
final SqmTypedNode<?> argument = arguments.get( 0 );
|
||||
final SqmExpressible<?> expressible = argument.getExpressible();
|
||||
final DomainType<?> domainType;
|
||||
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
|
||||
final JdbcType jdbcType = getJdbcType( domainType, typeConfiguration );
|
||||
if ( !isNumeric( jdbcType ) ) {
|
||||
throw new FunctionArgumentException(
|
||||
String.format(
|
||||
"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
|
||||
1,
|
||||
functionName,
|
||||
NUMERIC,
|
||||
domainType.getTypeName()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isNumeric(JdbcType jdbcType) {
|
||||
final int sqlTypeCode = jdbcType.getDefaultSqlTypeCode();
|
||||
if ( SqlTypes.isNumericType( sqlTypeCode ) ) {
|
||||
return true;
|
||||
}
|
||||
if ( jdbcType instanceof ArrayJdbcType ) {
|
||||
return isNumeric( ( (ArrayJdbcType) jdbcType ).getElementJdbcType() );
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static JdbcType getJdbcType(DomainType<?> domainType, TypeConfiguration typeConfiguration) {
|
||||
if ( domainType instanceof JdbcMapping ) {
|
||||
return ( (JdbcMapping) domainType ).getJdbcType();
|
||||
}
|
||||
else {
|
||||
final JavaType<?> javaType = domainType.getExpressibleJavaType();
|
||||
if ( javaType.getJavaTypeClass().isEnum() ) {
|
||||
// we can't tell if the enum is mapped STRING or ORDINAL
|
||||
return ObjectJdbcType.INSTANCE;
|
||||
}
|
||||
else {
|
||||
return javaType.getRecommendedJdbcType( typeConfiguration.getCurrentBaseSqlTypeIndicators() );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getSignature() {
|
||||
return "(arg)";
|
||||
}
|
||||
}
|
||||
|
||||
public static class ReturnTypeResolver implements FunctionReturnTypeResolver {
|
||||
|
||||
private final BasicType<Double> doubleType;
|
||||
|
||||
public ReturnTypeResolver(TypeConfiguration typeConfiguration) {
|
||||
this.doubleType = typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
|
||||
}
|
||||
|
||||
@Override
|
||||
public BasicValuedMapping resolveFunctionReturnType(
|
||||
Supplier<BasicValuedMapping> impliedTypeAccess,
|
||||
List<? extends SqlAstNode> arguments) {
|
||||
final BasicValuedMapping impliedType = impliedTypeAccess.get();
|
||||
if ( impliedType != null ) {
|
||||
return impliedType;
|
||||
}
|
||||
final JdbcMapping jdbcMapping = ( (Expression) arguments.get( 0 ) ).getExpressionType().getSingleJdbcMapping();
|
||||
if ( jdbcMapping instanceof BasicPluralType<?, ?> ) {
|
||||
return (BasicValuedMapping) jdbcMapping;
|
||||
}
|
||||
return doubleType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ReturnableType<?> resolveFunctionReturnType(
|
||||
ReturnableType<?> impliedType,
|
||||
Supplier<MappingModelExpressible<?>> inferredTypeSupplier,
|
||||
List<? extends SqmTypedNode<?>> arguments,
|
||||
TypeConfiguration typeConfiguration) {
|
||||
final SqmExpressible<?> expressible = arguments.get( 0 ).getExpressible();
|
||||
final DomainType<?> domainType;
|
||||
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
|
||||
if ( domainType instanceof BasicPluralType<?, ?> ) {
|
||||
return (ReturnableType<?>) domainType;
|
||||
}
|
||||
}
|
||||
return typeConfiguration.getBasicTypeRegistry().resolve( StandardBasicTypes.DOUBLE );
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2009,11 +2009,11 @@ public class CommonFunctionFactory {
|
|||
.setExactArgumentCount( 1 )
|
||||
.register();
|
||||
|
||||
|
||||
functionRegistry.namedAggregateDescriptorBuilder( "avg" )
|
||||
.setArgumentRenderingMode( inferenceArgumentRenderingMode )
|
||||
.setInvariantType(doubleType)
|
||||
.setExactArgumentCount( 1 )
|
||||
.setParameterTypes(NUMERIC)
|
||||
.setArgumentsValidator( AvgFunction.Validator.INSTANCE )
|
||||
.setReturnTypeResolver( new AvgFunction.ReturnTypeResolver( typeConfiguration ) )
|
||||
.register();
|
||||
|
||||
functionRegistry.register(
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
package org.hibernate.dialect.function;
|
||||
|
||||
import org.hibernate.metamodel.mapping.BasicValuedMapping;
|
||||
import org.hibernate.metamodel.mapping.JdbcMapping;
|
||||
import org.hibernate.metamodel.mapping.MappingModelExpressible;
|
||||
import org.hibernate.query.ReturnableType;
|
||||
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
|
||||
|
@ -17,12 +18,12 @@ import org.hibernate.type.spi.TypeConfiguration;
|
|||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.BigInteger;
|
||||
import java.sql.Types;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentType;
|
||||
import static org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers.extractArgumentValuedMapping;
|
||||
import static org.hibernate.type.SqlTypes.*;
|
||||
|
||||
/**
|
||||
* Resolve according to JPA spec 4.8.5
|
||||
|
@ -84,18 +85,20 @@ class SumReturnTypeResolver implements FunctionReturnTypeResolver {
|
|||
}
|
||||
}
|
||||
switch ( basicType.getJdbcType().getDefaultSqlTypeCode() ) {
|
||||
case Types.SMALLINT:
|
||||
case Types.TINYINT:
|
||||
case Types.INTEGER:
|
||||
case Types.BIGINT:
|
||||
case SMALLINT:
|
||||
case TINYINT:
|
||||
case INTEGER:
|
||||
case BIGINT:
|
||||
return longType;
|
||||
case Types.FLOAT:
|
||||
case Types.REAL:
|
||||
case Types.DOUBLE:
|
||||
case FLOAT:
|
||||
case REAL:
|
||||
case DOUBLE:
|
||||
return doubleType;
|
||||
case Types.DECIMAL:
|
||||
case Types.NUMERIC:
|
||||
case DECIMAL:
|
||||
case NUMERIC:
|
||||
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
|
||||
case VECTOR:
|
||||
return basicType;
|
||||
}
|
||||
return bigDecimalType;
|
||||
}
|
||||
|
@ -112,22 +115,23 @@ class SumReturnTypeResolver implements FunctionReturnTypeResolver {
|
|||
}
|
||||
// Resolve according to JPA spec 4.8.5
|
||||
final BasicValuedMapping specifiedArgType = extractArgumentValuedMapping( arguments, 1 );
|
||||
switch ( specifiedArgType.getJdbcMapping().getJdbcType().getDefaultSqlTypeCode() ) {
|
||||
case Types.SMALLINT:
|
||||
case Types.TINYINT:
|
||||
case Types.INTEGER:
|
||||
case Types.BIGINT:
|
||||
final JdbcMapping jdbcMapping = specifiedArgType.getJdbcMapping();
|
||||
switch ( jdbcMapping.getJdbcType().getDefaultSqlTypeCode() ) {
|
||||
case SMALLINT:
|
||||
case TINYINT:
|
||||
case INTEGER:
|
||||
case BIGINT:
|
||||
return longType;
|
||||
case Types.FLOAT:
|
||||
case Types.REAL:
|
||||
case Types.DOUBLE:
|
||||
case FLOAT:
|
||||
case REAL:
|
||||
case DOUBLE:
|
||||
return doubleType;
|
||||
case Types.DECIMAL:
|
||||
case Types.NUMERIC:
|
||||
final Class<?> argTypeClass = specifiedArgType.getJdbcMapping()
|
||||
.getJavaTypeDescriptor()
|
||||
.getJavaTypeClass();
|
||||
return BigInteger.class.isAssignableFrom(argTypeClass) ? bigIntegerType : bigDecimalType;
|
||||
case DECIMAL:
|
||||
case NUMERIC:
|
||||
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
|
||||
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
|
||||
case VECTOR:
|
||||
return (BasicValuedMapping) jdbcMapping;
|
||||
}
|
||||
return bigDecimalType;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.hibernate.query.sqm.tree.domain.SqmPath;
|
|||
import org.hibernate.query.sqm.tree.domain.SqmPluralValuedSimplePath;
|
||||
import org.hibernate.query.sqm.tree.expression.SqmExpression;
|
||||
import org.hibernate.query.sqm.tree.expression.SqmLiteralNull;
|
||||
import org.hibernate.type.BasicPluralType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
|
||||
|
@ -99,7 +100,8 @@ public class TypecheckUtil {
|
|||
* @see #isTypeAssignable(SqmPathSource, SqmExpressible, SessionFactoryImplementor)
|
||||
*/
|
||||
public static boolean areTypesComparable(
|
||||
SqmExpressible<?> lhsType, SqmExpressible<?> rhsType,
|
||||
SqmExpressible<?> lhsType,
|
||||
SqmExpressible<?> rhsType,
|
||||
SessionFactoryImplementor factory) {
|
||||
|
||||
if ( lhsType == null || rhsType == null || lhsType == rhsType ) {
|
||||
|
@ -118,7 +120,10 @@ public class TypecheckUtil {
|
|||
// for embeddables, the embeddable class must match exactly
|
||||
|
||||
if ( lhsType instanceof EmbeddedSqmPathSource && rhsType instanceof EmbeddedSqmPathSource ) {
|
||||
return areEmbeddableTypesComparable( (EmbeddedSqmPathSource<?>) lhsType, (EmbeddedSqmPathSource<?>) rhsType );
|
||||
return areEmbeddableTypesComparable(
|
||||
(EmbeddedSqmPathSource<?>) lhsType,
|
||||
(EmbeddedSqmPathSource<?>) rhsType
|
||||
);
|
||||
}
|
||||
|
||||
// for tuple constructors, we must check each element
|
||||
|
@ -186,12 +191,17 @@ public class TypecheckUtil {
|
|||
return false;
|
||||
}
|
||||
|
||||
private static boolean areEmbeddableTypesComparable(EmbeddedSqmPathSource<?> lhsType, EmbeddedSqmPathSource<?> rhsType) {
|
||||
private static boolean areEmbeddableTypesComparable(
|
||||
EmbeddedSqmPathSource<?> lhsType,
|
||||
EmbeddedSqmPathSource<?> rhsType) {
|
||||
// no polymorphism for embeddable types
|
||||
return rhsType.getNodeJavaType() == lhsType.getNodeJavaType();
|
||||
}
|
||||
|
||||
private static boolean areTupleTypesComparable(SessionFactoryImplementor factory, TupleType<?> lhsTuple, TupleType<?> rhsTuple) {
|
||||
private static boolean areTupleTypesComparable(
|
||||
SessionFactoryImplementor factory,
|
||||
TupleType<?> lhsTuple,
|
||||
TupleType<?> rhsTuple) {
|
||||
if ( rhsTuple.componentCount() != lhsTuple.componentCount() ) {
|
||||
return false;
|
||||
}
|
||||
|
@ -387,16 +397,18 @@ public class TypecheckUtil {
|
|||
if ( !Number.class.isAssignableFrom( rightJavaType )
|
||||
// we can scale a duration by a number
|
||||
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a numeric type (it is not an instance of 'java.lang.Number' or 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
if ( !Number.class.isAssignableFrom( rightJavaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -408,15 +420,17 @@ public class TypecheckUtil {
|
|||
case SUBTRACT:
|
||||
// we can add/subtract durations
|
||||
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
|
||||
+ " (it is not an instance of 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + leftNodeType.getTypeName() +
|
||||
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
|
||||
);
|
||||
}
|
||||
}
|
||||
else if ( Temporal.class.isAssignableFrom( leftJavaType )
|
||||
|
@ -426,9 +440,10 @@ public class TypecheckUtil {
|
|||
case ADD:
|
||||
// we can add a duration to date, time, or datetime
|
||||
if ( !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
|
||||
+ " (it is not an instance of 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
break;
|
||||
case SUBTRACT:
|
||||
|
@ -437,32 +452,57 @@ public class TypecheckUtil {
|
|||
&& !java.util.Date.class.isAssignableFrom( rightJavaType )
|
||||
// we can subtract a duration from a date, time, or datetime
|
||||
&& !TemporalAmount.class.isAssignableFrom( rightJavaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + rightNodeType.getTypeName() + "' which is not a temporal amount"
|
||||
+ " (it is not an instance of 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a temporal amount (it is not an instance of 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + leftNodeType.getTypeName() +
|
||||
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
|
||||
);
|
||||
}
|
||||
}
|
||||
else if ( isNumberArray( leftNodeType ) ) {
|
||||
// left operand is a number
|
||||
if ( !isNumberArray( rightNodeType ) ) {
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText() + " is of type '" + rightNodeType.getTypeName() +
|
||||
"' which is not a numeric array type" + " (it is not an instance of 'java.lang.Number[]')"
|
||||
);
|
||||
}
|
||||
}
|
||||
else {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorSqlText()
|
||||
+ " is of type '" + leftNodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number', 'java.time.Temporal', or 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isNumberArray(SqmExpressible<?> expressible) {
|
||||
final DomainType<?> domainType;
|
||||
if ( expressible != null && ( domainType = expressible.getSqmType() ) != null ) {
|
||||
return domainType instanceof BasicPluralType<?, ?> && Number.class.isAssignableFrom(
|
||||
( (BasicPluralType<?, ?>) domainType ).getElementType().getJavaType()
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static void assertString(SqmExpression<?> expression) {
|
||||
final SqmExpressible<?> nodeType = expression.getNodeType();
|
||||
if ( nodeType != null ) {
|
||||
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
|
||||
if ( javaType != String.class && javaType != char[].class ) {
|
||||
throw new SemanticException( "Operand of 'like' is of type '" + nodeType.getTypeName() + "' which is not a string"
|
||||
+ " (it is not an instance of 'java.lang.String' or 'char[]')" );
|
||||
throw new SemanticException(
|
||||
"Operand of 'like' is of type '" + nodeType.getTypeName() +
|
||||
"' which is not a string (it is not an instance of 'java.lang.String' or 'char[]')"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -487,8 +527,10 @@ public class TypecheckUtil {
|
|||
if ( nodeType != null ) {
|
||||
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
|
||||
if ( !TemporalAmount.class.isAssignableFrom( javaType ) ) {
|
||||
throw new SemanticException( "Operand of 'by' is of type '" + nodeType.getTypeName() + "' which is not a duration"
|
||||
+ " (it is not an instance of 'java.time.TemporalAmount')" );
|
||||
throw new SemanticException(
|
||||
"Operand of 'by' is of type '" + nodeType.getTypeName() +
|
||||
"' which is not a duration (it is not an instance of 'java.time.TemporalAmount')"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -498,9 +540,10 @@ public class TypecheckUtil {
|
|||
if ( nodeType != null ) {
|
||||
final Class<?> javaType = nodeType.getExpressibleJavaType().getJavaTypeClass();
|
||||
if ( !Number.class.isAssignableFrom( javaType ) ) {
|
||||
throw new SemanticException( "Operand of " + op.getOperatorChar()
|
||||
+ " is of type '" + nodeType.getTypeName() + "' which is not a numeric type"
|
||||
+ " (it is not an instance of 'java.lang.Number')" );
|
||||
throw new SemanticException(
|
||||
"Operand of " + op.getOperatorChar() + " is of type '" + nodeType.getTypeName() +
|
||||
"' which is not a numeric type (it is not an instance of 'java.lang.Number')"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -218,14 +218,14 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
|
|||
&& isUnknown( ((BasicType<?>) expressionType).getJavaTypeDescriptor() );
|
||||
}
|
||||
|
||||
private int validateArgument(int count, JdbcMappingContainer expressionType, String functionName) {
|
||||
private int validateArgument(int paramNumber, JdbcMappingContainer expressionType, String functionName) {
|
||||
final int jdbcTypeCount = expressionType.getJdbcTypeCount();
|
||||
for ( int i = 0; i < jdbcTypeCount; i++ ) {
|
||||
final JdbcMapping mapping = expressionType.getJdbcMapping( i );
|
||||
FunctionParameterType type = count < types.length ? types[count++] : types[types.length - 1];
|
||||
if (type != null) {
|
||||
FunctionParameterType type = paramNumber < types.length ? types[paramNumber++] : types[types.length - 1];
|
||||
if ( type != null ) {
|
||||
checkArgumentType(
|
||||
count,
|
||||
paramNumber,
|
||||
functionName,
|
||||
type,
|
||||
mapping.getJdbcType().getDefaultSqlTypeCode(),
|
||||
|
@ -233,10 +233,10 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
|
|||
);
|
||||
}
|
||||
}
|
||||
return count;
|
||||
return paramNumber;
|
||||
}
|
||||
|
||||
private void checkArgumentType(int count, String functionName, FunctionParameterType type, int code, Type javaType) {
|
||||
private static void checkArgumentType(int paramNumber, String functionName, FunctionParameterType type, int code, Type javaType) {
|
||||
switch (type) {
|
||||
case COMPARABLE:
|
||||
if ( !isCharacterType(code) && !isTemporalType(code) && !isNumericType(code) && !isEnumType( code )
|
||||
|
@ -246,63 +246,63 @@ public class ArgumentTypesValidator implements ArgumentsValidator {
|
|||
// as a special case, we consider a binary column
|
||||
// comparable when it is mapped by a Java UUID
|
||||
&& !( javaType == java.util.UUID.class && code == Types.BINARY ) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case STRING:
|
||||
if ( !isCharacterType(code) && !isEnumType(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case STRING_OR_CLOB:
|
||||
if ( !isCharacterOrClobType(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case NUMERIC:
|
||||
if ( !isNumericType(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case INTEGER:
|
||||
if ( !isIntegral(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case BOOLEAN:
|
||||
// ugh, need to be careful here, need to accept all the
|
||||
// JDBC type codes that a Dialect might use for BOOLEAN
|
||||
if ( code != BOOLEAN && code != BIT && code != TINYINT && code != SMALLINT ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case TEMPORAL:
|
||||
if ( !isTemporalType(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case DATE:
|
||||
if ( !hasDatePart(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case TIME:
|
||||
if ( !hasTimePart(code) ) {
|
||||
throwError(type, javaType, functionName, count);
|
||||
throwError(type, javaType, functionName, paramNumber);
|
||||
}
|
||||
break;
|
||||
case SPATIAL:
|
||||
if ( !isSpatialType( code ) ) {
|
||||
throwError( type, javaType, functionName, count );
|
||||
throwError( type, javaType, functionName, paramNumber );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void throwError(FunctionParameterType type, Type javaType, String functionName, int count) {
|
||||
private static void throwError(FunctionParameterType type, Type javaType, String functionName, int paramNumber) {
|
||||
throw new FunctionArgumentException(
|
||||
String.format(
|
||||
"Parameter %d of function '%s()' has type '%s', but argument is of type '%s'",
|
||||
count,
|
||||
paramNumber,
|
||||
functionName,
|
||||
type,
|
||||
javaType.getTypeName()
|
||||
|
|
|
@ -388,6 +388,8 @@ import org.hibernate.type.SqlTypes;
|
|||
import org.hibernate.type.descriptor.converter.spi.BasicValueConverter;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.java.JavaTypeHelper;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcTypeIndicators;
|
||||
import org.hibernate.type.internal.BasicTypeImpl;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
@ -5861,7 +5863,7 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
|
|||
// Only use the inferred mapping as parameter type when the JavaType accepts values of the bind type
|
||||
if ( inferredJdbcMapping.getMappedJavaType().isWider( paramJdbcMapping.getMappedJavaType() )
|
||||
// and the bind type is not explicit or the bind type has the same JDBC type
|
||||
&& ( !bindingTypeExplicit || paramJdbcMapping.getJdbcType() == inferredJdbcMapping.getJdbcType() ) ) {
|
||||
&& ( !bindingTypeExplicit || canUseInferredType( paramJdbcMapping, inferredJdbcMapping ) ) ) {
|
||||
return resolveInferredValueMappingForParameter( inferredValueMapping );
|
||||
}
|
||||
}
|
||||
|
@ -5952,6 +5954,15 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
|
|||
throw new ConversionException( "Could not determine ValueMapping for SqmParameter: " + sqmParameter );
|
||||
}
|
||||
|
||||
private static boolean canUseInferredType(JdbcMapping bindJdbcMapping, JdbcMapping inferredJdbcMapping) {
|
||||
final JdbcType bindJdbcType = bindJdbcMapping.getJdbcType();
|
||||
final JdbcType inferredJdbcType = inferredJdbcMapping.getJdbcType();
|
||||
// If the bind type has a different JDBC type, we prefer that over the inferred type.
|
||||
return bindJdbcType == inferredJdbcType
|
||||
|| bindJdbcType instanceof ArrayJdbcType && inferredJdbcType instanceof ArrayJdbcType
|
||||
&& ( (ArrayJdbcType) bindJdbcType ).getElementJdbcType() == ( (ArrayJdbcType) inferredJdbcType ).getElementJdbcType();
|
||||
}
|
||||
|
||||
private MappingModelExpressible<?> resolveInferredValueMappingForParameter(MappingModelExpressible<?> inferredValueMapping) {
|
||||
if ( inferredValueMapping instanceof PluralAttributeMapping ) {
|
||||
// For parameters, we resolve to the element descriptor
|
||||
|
@ -6336,6 +6347,9 @@ public abstract class BaseSqmToSqlAstConverter<T extends Statement> extends Base
|
|||
if ( nodeType instanceof BasicValuedMapping ) {
|
||||
return (BasicValuedMapping) nodeType;
|
||||
}
|
||||
else if ( nodeType.getSqmType() instanceof BasicValuedMapping ) {
|
||||
return (BasicValuedMapping) nodeType.getSqmType();
|
||||
}
|
||||
else {
|
||||
return getTypeConfiguration().getBasicTypeForJavaType(
|
||||
nodeType.getExpressibleJavaType().getJavaTypeClass()
|
||||
|
|
|
@ -212,7 +212,7 @@ public class BasicTypeRegistry implements Serializable {
|
|||
}
|
||||
|
||||
public void register(BasicType<?> type, String key) {
|
||||
typesByName.put( key, type );
|
||||
register( type, new String[]{ key } );
|
||||
}
|
||||
|
||||
public void register(BasicType<?> type, String... keys) {
|
||||
|
|
|
@ -551,6 +551,16 @@ public class SqlTypes {
|
|||
*/
|
||||
public static final int NAMED_ENUM = 6001;
|
||||
|
||||
|
||||
/**
|
||||
* A type code representing an {@code embedding vector} type for databases like
|
||||
* {@link org.hibernate.dialect.PostgreSQLDialect PostgreSQL} that have special extensions.
|
||||
* An embedding vector essentially is a {@code float[]} with a fixed size.
|
||||
*
|
||||
* @since 6.4
|
||||
*/
|
||||
public static final int VECTOR = 10_000;
|
||||
|
||||
private SqlTypes() {
|
||||
}
|
||||
|
||||
|
|
|
@ -744,6 +744,15 @@ public final class StandardBasicTypes {
|
|||
);
|
||||
|
||||
|
||||
/**
|
||||
* The standard Hibernate type for mapping {@code float[]} to JDBC {@link org.hibernate.type.SqlTypes#VECTOR VECTOR},
|
||||
* specifically for embedding vectors like provided by the PostgreSQL extension pgvector.
|
||||
*/
|
||||
public static final BasicTypeReference<float[]> VECTOR = new BasicTypeReference<>(
|
||||
"vector", float[].class, SqlTypes.VECTOR
|
||||
);
|
||||
|
||||
|
||||
public static void prime(TypeConfiguration typeConfiguration) {
|
||||
BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
|
||||
|
@ -1236,6 +1245,13 @@ public final class StandardBasicTypes {
|
|||
"url", java.net.URL.class.getName()
|
||||
);
|
||||
|
||||
handle(
|
||||
VECTOR,
|
||||
null,
|
||||
basicTypeRegistry,
|
||||
"vector"
|
||||
);
|
||||
|
||||
|
||||
// Specialized version handlers
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
|
||||
description = 'Hibernate\'s type extensions'
|
||||
|
||||
apply from: rootProject.file( 'gradle/published-java-module.gradle' )
|
||||
|
||||
dependencies {
|
||||
api project( ':hibernate-core' )
|
||||
|
||||
testImplementation project( ':hibernate-testing' )
|
||||
testImplementation project( path: ':hibernate-core', configuration: 'tests' )
|
||||
}
|
||||
|
||||
test {
|
||||
include '**/**'
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import org.hibernate.boot.model.FunctionContributions;
|
||||
import org.hibernate.boot.model.FunctionContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.PostgreSQLDialect;
|
||||
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
|
||||
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
|
||||
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class PGVectorFunctionContributor implements FunctionContributor {
|
||||
|
||||
@Override
|
||||
public void contributeFunctions(FunctionContributions functionContributions) {
|
||||
final SqmFunctionRegistry functionRegistry = functionContributions.getFunctionRegistry();
|
||||
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
final Dialect dialect = functionContributions.getDialect();
|
||||
if ( dialect instanceof PostgreSQLDialect ) {
|
||||
final BasicType<Double> doubleType = basicTypeRegistry.resolve( StandardBasicTypes.DOUBLE );
|
||||
final BasicType<Integer> integerType = basicTypeRegistry.resolve( StandardBasicTypes.INTEGER );
|
||||
functionRegistry.patternDescriptorBuilder( "cosine_distance", "?1<=>?2" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "euclidean_distance", "?1<->?2" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.registerAlternateKey( "l2_distance", "euclidean_distance" );
|
||||
functionRegistry.namedDescriptorBuilder( "l1_distance" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.registerAlternateKey( "taxicab_distance", "l1_distance" );
|
||||
|
||||
functionRegistry.patternDescriptorBuilder( "negative_inner_product", "?1<#>?2" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.patternDescriptorBuilder( "inner_product", "(?1<#>?2)*-1" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 2 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
functionRegistry.namedDescriptorBuilder( "vector_dims" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 1 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( integerType ) )
|
||||
.register();
|
||||
functionRegistry.namedDescriptorBuilder( "vector_norm" )
|
||||
.setArgumentsValidator( StandardArgumentsValidators.composite(
|
||||
StandardArgumentsValidators.exactly( 1 ),
|
||||
VectorArgumentValidator.INSTANCE
|
||||
) )
|
||||
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
|
||||
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.invariant( doubleType ) )
|
||||
.register();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordinal() {
|
||||
return 200;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import java.lang.reflect.Type;
|
||||
|
||||
import org.hibernate.boot.model.TypeContributions;
|
||||
import org.hibernate.boot.model.TypeContributor;
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.dialect.PostgreSQLDialect;
|
||||
import org.hibernate.engine.jdbc.Size;
|
||||
import org.hibernate.engine.jdbc.spi.JdbcServices;
|
||||
import org.hibernate.service.ServiceRegistry;
|
||||
import org.hibernate.type.BasicArrayType;
|
||||
import org.hibernate.type.BasicType;
|
||||
import org.hibernate.type.BasicTypeRegistry;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
|
||||
import org.hibernate.type.descriptor.sql.internal.DdlTypeImpl;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class PGVectorTypeContributor implements TypeContributor {
|
||||
|
||||
private static final Type[] VECTOR_JAVA_TYPES = {
|
||||
Float[].class,
|
||||
float[].class
|
||||
};
|
||||
|
||||
@Override
|
||||
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
|
||||
final Dialect dialect = serviceRegistry.getService( JdbcServices.class ).getDialect();
|
||||
if ( dialect instanceof PostgreSQLDialect ) {
|
||||
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
|
||||
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
|
||||
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
|
||||
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
|
||||
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
|
||||
final ArrayJdbcType vectorJdbcType = new VectorJdbcType( jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ) );
|
||||
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, vectorJdbcType );
|
||||
for ( Type vectorJavaType : VECTOR_JAVA_TYPES ) {
|
||||
basicTypeRegistry.register(
|
||||
new BasicArrayType<>(
|
||||
floatBasicType,
|
||||
vectorJdbcType,
|
||||
javaTypeRegistry.getDescriptor( vectorJavaType )
|
||||
),
|
||||
StandardBasicTypes.VECTOR.getName()
|
||||
);
|
||||
}
|
||||
typeConfiguration.getDdlTypeRegistry().addDescriptor(
|
||||
new DdlTypeImpl( SqlTypes.VECTOR, "vector($l)", "vector", dialect ) {
|
||||
@Override
|
||||
public String getTypeName(Size size) {
|
||||
return getTypeName(
|
||||
size.getArrayLength() == null ? null : size.getArrayLength().longValue(),
|
||||
null,
|
||||
null
|
||||
);
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.metamodel.mapping.MappingModelExpressible;
|
||||
import org.hibernate.query.sqm.produce.function.FunctionArgumentTypeResolver;
|
||||
import org.hibernate.query.sqm.sql.SqmToSqlAstConverter;
|
||||
import org.hibernate.query.sqm.tree.SqmTypedNode;
|
||||
import org.hibernate.query.sqm.tree.expression.SqmExpression;
|
||||
import org.hibernate.query.sqm.tree.expression.SqmFunction;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.StandardBasicTypes;
|
||||
|
||||
/**
|
||||
* A {@link FunctionArgumentTypeResolver} for {@link SqlTypes#VECTOR} functions.
|
||||
*/
|
||||
public class VectorArgumentTypeResolver implements FunctionArgumentTypeResolver {
|
||||
|
||||
public static final FunctionArgumentTypeResolver INSTANCE = new VectorArgumentTypeResolver();
|
||||
|
||||
@Override
|
||||
public MappingModelExpressible<?> resolveFunctionArgumentType(
|
||||
SqmFunction<?> function,
|
||||
int argumentIndex,
|
||||
SqmToSqlAstConverter converter) {
|
||||
final List<? extends SqmTypedNode<?>> arguments = function.getArguments();
|
||||
for ( int i = 0; i < arguments.size(); i++ ) {
|
||||
if ( i != argumentIndex ) {
|
||||
final SqmTypedNode<?> node = arguments.get( i );
|
||||
if ( node instanceof SqmExpression<?> ) {
|
||||
final MappingModelExpressible<?> expressible = converter.determineValueMapping( (SqmExpression<?>) node );
|
||||
if ( expressible != null ) {
|
||||
return expressible;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return converter.getCreationContext()
|
||||
.getSessionFactory()
|
||||
.getTypeConfiguration()
|
||||
.getBasicTypeRegistry()
|
||||
.resolve( StandardBasicTypes.VECTOR );
|
||||
}
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.metamodel.model.domain.DomainType;
|
||||
import org.hibernate.query.sqm.SqmExpressible;
|
||||
import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
|
||||
import org.hibernate.query.sqm.produce.function.FunctionArgumentException;
|
||||
import org.hibernate.query.sqm.tree.SqmTypedNode;
|
||||
import org.hibernate.type.BasicPluralType;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
/**
|
||||
* A {@link ArgumentsValidator} that validates the arguments are all vector types i.e. {@link org.hibernate.type.SqlTypes#VECTOR}.
|
||||
*/
|
||||
public class VectorArgumentValidator implements ArgumentsValidator {
|
||||
|
||||
public static final ArgumentsValidator INSTANCE = new VectorArgumentValidator();
|
||||
|
||||
@Override
|
||||
public void validate(
|
||||
List<? extends SqmTypedNode<?>> arguments,
|
||||
String functionName,
|
||||
TypeConfiguration typeConfiguration) {
|
||||
for ( int i = 0; i < arguments.size(); i++ ) {
|
||||
final SqmExpressible<?> expressible = arguments.get( i ).getExpressible();
|
||||
final DomainType<?> type;
|
||||
if ( expressible != null && ( type = expressible.getSqmType() ) != null && !isVectorType( type ) ) {
|
||||
throw new FunctionArgumentException(
|
||||
String.format(
|
||||
"Parameter %d of function '%s()' requires a vector type, but argument is of type '%s'",
|
||||
i,
|
||||
functionName,
|
||||
type.getTypeName()
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isVectorType(SqmExpressible<?> vectorType) {
|
||||
return vectorType instanceof BasicPluralType<?, ?>
|
||||
&& ( (BasicPluralType<?, ?>) vectorType ).getJdbcType().getDefaultSqlTypeCode() == SqlTypes.VECTOR;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import java.sql.CallableStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.BitSet;
|
||||
|
||||
import org.hibernate.dialect.Dialect;
|
||||
import org.hibernate.sql.ast.spi.SqlAppender;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
import org.hibernate.type.descriptor.ValueExtractor;
|
||||
import org.hibernate.type.descriptor.WrapperOptions;
|
||||
import org.hibernate.type.descriptor.java.JavaType;
|
||||
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
|
||||
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
|
||||
import org.hibernate.type.descriptor.jdbc.JdbcType;
|
||||
import org.hibernate.type.spi.TypeConfiguration;
|
||||
|
||||
public class VectorJdbcType extends ArrayJdbcType {
|
||||
|
||||
private static final float[] EMPTY = new float[0];
|
||||
public VectorJdbcType(JdbcType elementJdbcType) {
|
||||
super( elementJdbcType );
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getDefaultSqlTypeCode() {
|
||||
return SqlTypes.VECTOR;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
|
||||
Integer precision,
|
||||
Integer scale,
|
||||
TypeConfiguration typeConfiguration) {
|
||||
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
|
||||
}
|
||||
|
||||
@Override
|
||||
public void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect) {
|
||||
appender.append( "cast(" );
|
||||
appender.append( writeExpression );
|
||||
appender.append( " as vector)" );
|
||||
}
|
||||
|
||||
@Override
|
||||
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
|
||||
return new BasicExtractor<>( javaTypeDescriptor, this ) {
|
||||
@Override
|
||||
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( getFloatArray( rs.getString( paramIndex ) ), options );
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( getFloatArray( statement.getString( index ) ), options );
|
||||
}
|
||||
|
||||
@Override
|
||||
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
|
||||
return javaTypeDescriptor.wrap( getFloatArray( statement.getString( name ) ), options );
|
||||
}
|
||||
|
||||
private float[] getFloatArray(String string) {
|
||||
if ( string.length() == 2 ) {
|
||||
return EMPTY;
|
||||
}
|
||||
final BitSet commaPositions = new BitSet();
|
||||
int size = 1;
|
||||
for ( int i = 1; i < string.length(); i++ ) {
|
||||
final char c = string.charAt( i );
|
||||
if ( c == ',' ) {
|
||||
commaPositions.set( i );
|
||||
size++;
|
||||
}
|
||||
}
|
||||
final float[] result = new float[size];
|
||||
int floatStartIndex = 1;
|
||||
int commaIndex;
|
||||
int index = 0;
|
||||
while ( ( commaIndex = commaPositions.nextSetBit( floatStartIndex ) ) != -1 ) {
|
||||
result[index++] = Float.parseFloat( string.substring( floatStartIndex, commaIndex ) );
|
||||
floatStartIndex = commaIndex + 1;
|
||||
}
|
||||
result[index] = Float.parseFloat( string.substring( floatStartIndex, string.length() - 1 ) );
|
||||
return result;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
org.hibernate.types.vector.PGVectorFunctionContributor
|
|
@ -0,0 +1 @@
|
|||
org.hibernate.types.vector.PGVectorTypeContributor
|
|
@ -0,0 +1,303 @@
|
|||
/*
|
||||
* Hibernate, Relational Persistence for Idiomatic Java
|
||||
*
|
||||
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
|
||||
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
|
||||
*/
|
||||
package org.hibernate.types.vector;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.hibernate.annotations.Array;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.dialect.PostgreSQLDialect;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import org.hibernate.testing.orm.junit.DomainModel;
|
||||
import org.hibernate.testing.orm.junit.RequiresDialect;
|
||||
import org.hibernate.testing.orm.junit.SessionFactory;
|
||||
import org.hibernate.testing.orm.junit.SessionFactoryScope;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import jakarta.persistence.Column;
|
||||
import jakarta.persistence.Entity;
|
||||
import jakarta.persistence.Id;
|
||||
import jakarta.persistence.Tuple;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
/**
|
||||
* @author Christian Beikov
|
||||
*/
|
||||
@DomainModel(annotatedClasses = PGVectorTest.VectorEntity.class)
|
||||
@SessionFactory
|
||||
@RequiresDialect(value = PostgreSQLDialect.class, matchSubTypes = false)
|
||||
public class PGVectorTest {
|
||||
|
||||
private static final float[] V1 = new float[]{ 1, 2, 3 };
|
||||
private static final float[] V2 = new float[]{ 4, 5, 6 };
|
||||
|
||||
@BeforeEach
|
||||
public void prepareData(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.persist( new VectorEntity( 1L, V1 ) );
|
||||
em.persist( new VectorEntity( 2L, V2 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
public void cleanup(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
em.createMutationQuery( "delete from VectorEntity" ).executeUpdate();
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRead(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
VectorEntity tableRecord;
|
||||
tableRecord = em.find( VectorEntity.class, 1L );
|
||||
assertArrayEquals( new float[]{ 1, 2, 3 }, tableRecord.getTheVector(), 0 );
|
||||
|
||||
tableRecord = em.find( VectorEntity.class, 2L );
|
||||
assertArrayEquals( new float[]{ 4, 5, 6 }, tableRecord.getTheVector(), 0 );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCosineDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::cosine-distance-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, cosine_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::cosine-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0.0000000000000002D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( cosineDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0.0000000000000002D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEuclideanDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::euclidean-distance-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, euclidean_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::euclidean-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTaxicabDistance(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::taxicab-distance-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, taxicab_distance(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::taxicab-distance-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( taxicabDistance( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInnerProduct(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::inner-product-example[]
|
||||
final float[] vector = new float[]{ 1, 1, 1 };
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, inner_product(e.theVector, :vec), negative_inner_product(e.theVector, :vec) from VectorEntity e order by e.id", Tuple.class )
|
||||
.setParameter( "vec", vector )
|
||||
.getResultList();
|
||||
//end::inner-product-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V1, vector ), results.get( 0 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( innerProduct( V1, vector ) * -1, results.get( 0 ).get( 2, Double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( innerProduct( V2, vector ), results.get( 1 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( innerProduct( V2, vector ) * -1, results.get( 1 ).get( 2, Double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorDims(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-dims-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_dims(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-dims-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( V1.length, results.get( 0 ).get( 1 ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( V2.length, results.get( 1 ).get( 1 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorNorm(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-norm-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, vector_norm(e.theVector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-norm-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V1 ), results.get( 0 ).get( 1, Double.class ), 0D );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertEquals( euclideanNorm( V2 ), results.get( 1 ).get( 1, Double.class ), 0D );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorSum(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-sum-example[]
|
||||
final List<float[]> results = em.createSelectionQuery( "select sum(e.theVector) from VectorEntity e", float[].class )
|
||||
.getResultList();
|
||||
//end::vector-sum-example[]
|
||||
assertEquals( 1, results.size() );
|
||||
assertArrayEquals( new float[]{ 5, 7, 9 }, results.get( 0 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testVectorAvg(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-avg-example[]
|
||||
final List<float[]> results = em.createSelectionQuery( "select avg(e.theVector) from VectorEntity e", float[].class )
|
||||
.getResultList();
|
||||
//end::vector-avg-example[]
|
||||
assertEquals( 1, results.size() );
|
||||
assertArrayEquals( new float[]{ 2.5f, 3.5f, 4.5f }, results.get( 0 ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddition(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-addition-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, e.theVector + cast('[1, 1, 1]' as vector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-addition-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertArrayEquals( new float[]{ 2, 3, 4 }, results.get( 0 ).get( 1, float[].class ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertArrayEquals( new float[]{ 5, 6, 7 }, results.get( 1 ).get( 1, float[].class ) );
|
||||
} );
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiplication(SessionFactoryScope scope) {
|
||||
scope.inTransaction( em -> {
|
||||
//tag::vector-multiplication-example[]
|
||||
final List<Tuple> results = em.createSelectionQuery( "select e.id, e.theVector * cast('[2, 2, 2]' as vector) from VectorEntity e order by e.id", Tuple.class )
|
||||
.getResultList();
|
||||
//end::vector-multiplication-example[]
|
||||
assertEquals( 2, results.size() );
|
||||
assertEquals( 1L, results.get( 0 ).get( 0 ) );
|
||||
assertArrayEquals( new float[]{ 2, 4, 6 }, results.get( 0 ).get( 1, float[].class ) );
|
||||
assertEquals( 2L, results.get( 1 ).get( 0 ) );
|
||||
assertArrayEquals( new float[]{ 8, 10, 12 }, results.get( 1 ).get( 1, float[].class ) );
|
||||
} );
|
||||
}
|
||||
|
||||
private static double cosineDistance(float[] f1, float[] f2) {
|
||||
return 1D - innerProduct( f1, f2 ) / ( euclideanNorm( f1 ) * euclideanNorm( f2 ) );
|
||||
}
|
||||
|
||||
private static double euclideanDistance(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += Math.pow( (double) f1[i] - f2[i], 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double taxicabDistance(float[] f1, float[] f2) {
|
||||
return norm( f1 ) - norm( f2 );
|
||||
}
|
||||
|
||||
private static double innerProduct(float[] f1, float[] f2) {
|
||||
assert f1.length == f2.length;
|
||||
double result = 0;
|
||||
for ( int i = 0; i < f1.length; i++ ) {
|
||||
result += ( (double) f1[i] ) * ( (double) f2[i] );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static double euclideanNorm(float[] f) {
|
||||
double result = 0;
|
||||
for ( float v : f ) {
|
||||
result += Math.pow( v, 2 );
|
||||
}
|
||||
return Math.sqrt( result );
|
||||
}
|
||||
|
||||
private static double norm(float[] f) {
|
||||
double result = 0;
|
||||
for ( float v : f ) {
|
||||
result += Math.abs( v );
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Entity( name = "VectorEntity" )
|
||||
public static class VectorEntity {
|
||||
|
||||
@Id
|
||||
private Long id;
|
||||
|
||||
//tag::usage-example[]
|
||||
@Column( name = "the_vector" )
|
||||
@JdbcTypeCode(SqlTypes.VECTOR)
|
||||
@Array(length = 3)
|
||||
private float[] theVector;
|
||||
//end::usage-example[]
|
||||
|
||||
public VectorEntity() {
|
||||
}
|
||||
|
||||
public VectorEntity(Long id, float[] theVector) {
|
||||
this.id = id;
|
||||
this.theVector = theVector;
|
||||
}
|
||||
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public float[] getTheVector() {
|
||||
return theVector;
|
||||
}
|
||||
|
||||
public void setTheVector(float[] theVector) {
|
||||
this.theVector = theVector;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -300,6 +300,7 @@ include 'hibernate-spatial'
|
|||
include 'hibernate-platform'
|
||||
|
||||
include 'hibernate-community-dialects'
|
||||
include 'hibernate-types'
|
||||
|
||||
include 'hibernate-c3p0'
|
||||
include 'hibernate-proxool'
|
||||
|
|
Loading…
Reference in New Issue