HHH-15160 - Properly validate the arguments in the distance operators.

This commit is contained in:
Karel Maesen 2023-04-27 21:40:36 +02:00 committed by Christian Beikov
parent 0327531c59
commit 5483f403b1
3 changed files with 42 additions and 6 deletions

View File

@ -11,8 +11,10 @@ import java.util.List;
import org.hibernate.boot.model.FunctionContributions; import org.hibernate.boot.model.FunctionContributions;
import org.hibernate.query.sqm.function.NamedSqmFunctionDescriptor; import org.hibernate.query.sqm.function.NamedSqmFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.ArgumentTypesValidator;
import org.hibernate.query.sqm.produce.function.ArgumentsValidator; import org.hibernate.query.sqm.produce.function.ArgumentsValidator;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver; import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers; import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.spatial.BaseSqmFunctionDescriptors; import org.hibernate.spatial.BaseSqmFunctionDescriptors;
import org.hibernate.spatial.FunctionKey; import org.hibernate.spatial.FunctionKey;
@ -24,7 +26,7 @@ import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.BasicTypeRegistry; import org.hibernate.type.BasicTypeRegistry;
import org.hibernate.type.StandardBasicTypes; import org.hibernate.type.StandardBasicTypes;
import static org.hibernate.query.sqm.produce.function.StandardArgumentsValidators.exactly; import static org.hibernate.query.sqm.produce.function.FunctionParameterType.SPATIAL;
public class PostgisSqmFunctionDescriptors extends BaseSqmFunctionDescriptors { public class PostgisSqmFunctionDescriptors extends BaseSqmFunctionDescriptors {
@ -47,7 +49,7 @@ public class PostgisSqmFunctionDescriptors extends BaseSqmFunctionDescriptors {
new PostgisOperator( new PostgisOperator(
name, name,
operator, operator,
exactly( 2 ), new ArgumentTypesValidator( StandardArgumentsValidators.exactly( 2 ), SPATIAL ),
StandardFunctionReturnTypeResolvers.invariant( typeRegistry.resolve( StandardFunctionReturnTypeResolvers.invariant( typeRegistry.resolve(
StandardBasicTypes.DOUBLE ) StandardBasicTypes.DOUBLE )
) )

View File

@ -33,7 +33,9 @@ import org.geolatte.geom.crs.CoordinateReferenceSystems;
import static org.geolatte.geom.builder.DSL.c; import static org.geolatte.geom.builder.DSL.c;
import static org.geolatte.geom.builder.DSL.point; import static org.geolatte.geom.builder.DSL.point;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
@ -107,11 +109,43 @@ public class PostgisDistanceOperatorsTest {
List<Neighbor> results = query.getResultList(); List<Neighbor> results = query.getResultList();
assertFalse( results.isEmpty() ); assertFalse( results.isEmpty() );
String sql = inspector.getSqlQueries().get( 0 ); String sql = inspector.getSqlQueries().get( 0 );
assertTrue(sql.matches(".*order by.*point\\w*<<->>.*"), "<<->>> operator is not rendered correctly"); assertTrue(
sql.matches( ".*order by.*point\\w*<<->>.*" ),
"<<->>> operator is not rendered correctly"
);
} }
); );
} }
@Test
public void testInvalidArguments(SessionFactoryScope scope) {
SQLStatementInspector inspector = scope.getCollectingStatementInspector();
inspector.clear();
IllegalArgumentException thrown = assertThrows( IllegalArgumentException.class, () ->
scope.inTransaction(
session -> {
TypedQuery<Neighbor> query = session.createQuery(
"select n from Neighbor n order by distance_2d_bbox(n.point, :pnt )",
Neighbor.class
)
.setParameter( "pnt", 130 );
List<Neighbor> results = query.getResultList();
assertFalse( results.isEmpty() );
String sql = inspector.getSqlQueries().get( 0 );
assertTrue(
sql.matches( ".*order by.*point\\w*<#>.*" ),
"<#> operator is not rendered correctly"
);
}
)
);
assertEquals(
"Parameter 1 of function distance_2d_bbox() has type SPATIAL, but argument is of type java.lang.Integer",
thrown.getMessage()
);
}
@AfterEach @AfterEach
public void cleanUp(SessionFactoryScope scope) { public void cleanUp(SessionFactoryScope scope) {
scope.inTransaction( scope.inTransaction(