HHH-14154 Incorrect SQL generated from Criteria API when concat() and function() methods are used together

This commit is contained in:
Nathan Xu 2020-08-19 18:42:25 -04:00 committed by Sanne Grinovero
parent 4716262645
commit 4b97be39db
2 changed files with 77 additions and 9 deletions

View File

@ -16,15 +16,12 @@ import org.hibernate.QueryException;
import org.hibernate.dialect.function.SQLFunction;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.hql.internal.antlr.SqlGeneratorBase;
import org.hibernate.hql.internal.antlr.SqlTokenTypes;
import org.hibernate.hql.internal.ast.tree.CollectionPathNode;
import org.hibernate.hql.internal.ast.tree.CollectionSizeNode;
import org.hibernate.hql.internal.ast.tree.FromElement;
import org.hibernate.hql.internal.ast.tree.FunctionNode;
import org.hibernate.hql.internal.ast.tree.Node;
import org.hibernate.hql.internal.ast.tree.ParameterContainer;
import org.hibernate.hql.internal.ast.tree.ParameterNode;
import org.hibernate.hql.internal.ast.util.ASTPrinter;
import org.hibernate.hql.internal.ast.util.TokenPrinters;
import org.hibernate.internal.CoreLogging;
import org.hibernate.internal.CoreMessageLogger;
@ -34,7 +31,6 @@ import org.hibernate.param.ParameterSpecification;
import org.hibernate.type.Type;
import antlr.RecognitionException;
import antlr.SemanticException;
import antlr.collections.AST;
/**
@ -192,13 +188,16 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
// METHOD_NAME
FunctionNode functionNode = (FunctionNode) node;
SQLFunction sqlFunction = functionNode.getSQLFunction();
outputStack.addFirst( writer );
if ( sqlFunction == null ) {
// if SQLFunction is null we just write the function out as it appears in the hql statement
writer = new StandardFunctionArguments();
super.beginFunctionTemplate( node, nameNode );
}
else {
// this function has a registered SQLFunction -> redirect output and catch the arguments
outputStack.addFirst( writer );
if ( node.getType() == CAST ) {
writer = new CastFunctionArguments();
}
@ -212,13 +211,17 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
protected void endFunctionTemplate(AST node) {
FunctionNode functionNode = (FunctionNode) node;
SQLFunction sqlFunction = functionNode.getSQLFunction();
final FunctionArgumentsCollectingWriter functionArguments = (FunctionArgumentsCollectingWriter) writer;
if ( sqlFunction == null ) {
super.endFunctionTemplate( node );
writer = outputStack.removeFirst();
out( StringHelper.join( ",", functionArguments.getArgs().iterator() ) );
}
else {
final Type functionType = functionNode.getFirstArgumentType();
// this function has a registered SQLFunction -> redirect output and catch the arguments
FunctionArgumentsCollectingWriter functionArguments = (FunctionArgumentsCollectingWriter) writer;
writer = outputStack.removeFirst();
out( sqlFunction.render( functionType, functionArguments.getArgs(), sessionFactory ) );
}
@ -236,7 +239,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
}
interface FunctionArgumentsCollectingWriter extends SqlWriter {
public List getArgs();
List<String> getArgs();
}
/**
@ -262,7 +265,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
++argInd;
}
public List getArgs() {
public List<String> getArgs() {
return args;
}
}
@ -305,7 +308,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
startedType = true;
}
public List getArgs() {
public List<String> getArgs() {
List<String> rtn = CollectionHelper.arrayList( 2 );
rtn.add( castExpression );
rtn.add( castTargetType );

View File

@ -0,0 +1,65 @@
package org.hibernate.query.hhh14154;
import java.util.Date;
import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.Table;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Root;
import org.hibernate.dialect.H2Dialect;
import org.hibernate.testing.RequiresDialect;
import org.hibernate.testing.TestForIssue;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.junit.Test;
import static org.hibernate.testing.transaction.TransactionUtil.doInJPA;
/**
* @author Archie Cobbs
* @author Nathan Xu
*/
@RequiresDialect( H2Dialect.class )
@TestForIssue( jiraKey = "HHH-14154" )
public class HHH14154Test extends BaseCoreFunctionalTestCase {
@Override
protected Class<?>[] getAnnotatedClasses() {
return new Class<?>[] { HHH14154Test.Foo.class };
}
@Test
public void testNoExceptionThrown() {
doInJPA( this::sessionFactory, em -> {
final CriteriaBuilder cb = em.getCriteriaBuilder();
final CriteriaQuery<Foo> cq = cb.createQuery( Foo.class );
final Root<Foo> foo = cq.from( Foo.class );
cq.select(foo)
.where(
cb.lessThanOrEqualTo(
cb.concat(
cb.function( "FORMATDATETIME", String.class, foo.get( "startTime" ), cb.literal( "HH:mm:ss" ) ),
""
),
"17:00:00"
)
);
em.createQuery( cq ).getResultList();
} );
}
@Entity(name = "Foo")
@Table(name = "Foo")
public static class Foo {
@Id
private long id;
private Date startTime;
}
}