HHH-9100 - Improve CAST function support

This commit is contained in:
Steve Ebersole 2014-04-03 13:02:30 -05:00
parent efa5dc2f6c
commit 6ffbf8f3f3
9 changed files with 325 additions and 9 deletions

View File

@ -204,6 +204,8 @@ tokens
protected void processFunction(AST functionCall,boolean inSelect) throws SemanticException { } protected void processFunction(AST functionCall,boolean inSelect) throws SemanticException { }
protected void processCastFunction(AST functionCall,boolean inSelect) throws SemanticException { }
protected void processAggregation(AST node, boolean inSelect) throws SemanticException { } protected void processAggregation(AST node, boolean inSelect) throws SemanticException { }
protected void processConstructor(AST constructor) throws SemanticException { } protected void processConstructor(AST constructor) throws SemanticException { }
@ -626,6 +628,10 @@ functionCall
: #(METHOD_CALL {inFunctionCall=true;} pathAsIdent ( #(EXPR_LIST (exprOrSubquery)* ) )? ) { : #(METHOD_CALL {inFunctionCall=true;} pathAsIdent ( #(EXPR_LIST (exprOrSubquery)* ) )? ) {
processFunction( #functionCall, inSelect ); processFunction( #functionCall, inSelect );
inFunctionCall=false; inFunctionCall=false;
}
| #(CAST {inFunctionCall=true;} expr pathAsIdent) {
processCastFunction( #functionCall, inSelect );
inFunctionCall=false;
} }
| #(AGGREGATE aggregateExpr ) | #(AGGREGATE aggregateExpr )
; ;

View File

@ -84,7 +84,7 @@ tokens
// -- SQL tokens -- // -- SQL tokens --
// These aren't part of HQL, but the SQL fragment parser uses the HQL lexer, so they need to be declared here. // These aren't part of HQL, but the SQL fragment parser uses the HQL lexer, so they need to be declared here.
CASE="case"; CASE="case"; // a "searched case statement", whereas CASE2 represents a "simple case statement"
END="end"; END="end";
ELSE="else"; ELSE="else";
THEN="then"; THEN="then";
@ -92,7 +92,7 @@ tokens
ON="on"; ON="on";
WITH="with"; WITH="with";
// -- EJBQL tokens -- // -- JPAQL tokens --
BOTH="both"; BOTH="both";
EMPTY="empty"; EMPTY="empty";
LEADING="leading"; LEADING="leading";
@ -108,7 +108,8 @@ tokens
AGGREGATE; // One of the aggregate functions (e.g. min, max, avg) AGGREGATE; // One of the aggregate functions (e.g. min, max, avg)
ALIAS; ALIAS;
CONSTRUCTOR; CONSTRUCTOR;
CASE2; CASE2; // a "simple case statement", whereas CASE represents a "searched case statement"
CAST;
EXPR_LIST; EXPR_LIST;
FILTER_ENTITY; // FROM element injected because of a filter expression (happens during compilation phase 2) FILTER_ENTITY; // FROM element injected because of a filter expression (happens during compilation phase 2)
IN_LIST; IN_LIST;
@ -666,6 +667,7 @@ quantifiedExpression
// * function : differentiated from method call via explicit keyword // * function : differentiated from method call via explicit keyword
atom atom
: { validateSoftKeyword("function") && LA(2) == OPEN && LA(3) == QUOTED_STRING }? jpaFunctionSyntax : { validateSoftKeyword("function") && LA(2) == OPEN && LA(3) == QUOTED_STRING }? jpaFunctionSyntax
| { validateSoftKeyword("cast") && LA(2) == OPEN }? castFunction
| primaryExpression | primaryExpression
( (
DOT^ identifier DOT^ identifier
@ -677,12 +679,38 @@ atom
jpaFunctionSyntax! jpaFunctionSyntax!
: i:IDENT OPEN n:QUOTED_STRING COMMA a:exprList CLOSE { : i:IDENT OPEN n:QUOTED_STRING COMMA a:exprList CLOSE {
final String functionName = unquote( #n.getText() );
if ( functionName.equalsIgnoreCase( "cast" ) ) {
#i.setType( CAST );
#i.setText( #i.getText() + " (" + functionName + ")" );
final AST expression = #a.getFirstChild();
final AST type = expression.getNextSibling();
#jpaFunctionSyntax = #( #i, expression, type );
}
else {
#i.setType( METHOD_CALL ); #i.setType( METHOD_CALL );
#i.setText( #i.getText() + " (" + #n.getText() + ")" ); #i.setText( #i.getText() + " (" + functionName + ")" );
#jpaFunctionSyntax = #( #i, [IDENT, unquote( #n.getText() )], #a ); #jpaFunctionSyntax = #( #i, [IDENT, unquote( #n.getText() )], #a );
} }
}
; ;
castFunction!
: c:IDENT OPEN e:expression (AS)? t:castTargetType CLOSE {
#c.setType( CAST );
#castFunction = #( #c, #e, #t );
}
;
castTargetType
// the cast target type is Hibernate type name which is either:
// 1) a simple identifier
// 2) a simple identifier-(dot-identifier)* sequence
: identifier { handleDotIdent(); } ( options { greedy=true; } : DOT^ identifier )*
;
// level 0 - the basic element of an expression // level 0 - the basic element of an expression
primaryExpression primaryExpression
: identPrimary ( options {greedy=true;} : DOT^ "class" )? : identPrimary ( options {greedy=true;} : DOT^ "class" )?

View File

@ -465,12 +465,17 @@ methodCall
: #(m:METHOD_CALL i:METHOD_NAME { beginFunctionTemplate(m,i); } : #(m:METHOD_CALL i:METHOD_NAME { beginFunctionTemplate(m,i); }
( #(EXPR_LIST (arguments)? ) )? ( #(EXPR_LIST (arguments)? ) )?
{ endFunctionTemplate(m); } ) { endFunctionTemplate(m); } )
| #( c:CAST { beginFunctionTemplate(c,c); } expr castTargetType { endFunctionTemplate(c); } )
; ;
arguments arguments
: expr ( { commaBetweenParameters(", "); } expr )* : expr ( { commaBetweenParameters(", "); } expr )*
; ;
castTargetType
: i:IDENT { out(i); }
;
parameter parameter
: n:NAMED_PARAM { out(n); } : n:NAMED_PARAM { out(n); }
| p:PARAM { out(p); } | p:PARAM { out(p); }

View File

@ -36,6 +36,11 @@ import org.hibernate.type.Type;
* @author Gavin King * @author Gavin King
*/ */
public class CastFunction implements SQLFunction { public class CastFunction implements SQLFunction {
/**
* Singleton access
*/
public static final CastFunction INSTANCE = new CastFunction();
@Override @Override
public boolean hasArguments() { public boolean hasArguments() {
return true; return true;

View File

@ -47,6 +47,7 @@ import org.hibernate.hql.internal.antlr.HqlTokenTypes;
import org.hibernate.hql.internal.antlr.SqlTokenTypes; import org.hibernate.hql.internal.antlr.SqlTokenTypes;
import org.hibernate.hql.internal.ast.tree.AggregateNode; import org.hibernate.hql.internal.ast.tree.AggregateNode;
import org.hibernate.hql.internal.ast.tree.AssignmentSpecification; import org.hibernate.hql.internal.ast.tree.AssignmentSpecification;
import org.hibernate.hql.internal.ast.tree.CastFunctionNode;
import org.hibernate.hql.internal.ast.tree.CollectionFunction; import org.hibernate.hql.internal.ast.tree.CollectionFunction;
import org.hibernate.hql.internal.ast.tree.ConstructorNode; import org.hibernate.hql.internal.ast.tree.ConstructorNode;
import org.hibernate.hql.internal.ast.tree.DeleteStatement; import org.hibernate.hql.internal.ast.tree.DeleteStatement;
@ -1076,6 +1077,12 @@ public class HqlSqlWalker extends HqlSqlBaseWalker implements ErrorReporter, Par
methodNode.resolve( inSelect ); methodNode.resolve( inSelect );
} }
@Override
protected void processCastFunction(AST castFunctionCall, boolean inSelect) throws SemanticException {
CastFunctionNode castFunctionNode = (CastFunctionNode) castFunctionCall;
castFunctionNode.resolve( inSelect );
}
@Override @Override
protected void processAggregation(AST node, boolean inSelect) throws SemanticException { protected void processAggregation(AST node, boolean inSelect) throws SemanticException {
AggregateNode aggregateNode = (AggregateNode) node; AggregateNode aggregateNode = (AggregateNode) node;

View File

@ -32,6 +32,7 @@ import org.hibernate.hql.internal.ast.tree.BetweenOperatorNode;
import org.hibernate.hql.internal.ast.tree.BinaryArithmeticOperatorNode; import org.hibernate.hql.internal.ast.tree.BinaryArithmeticOperatorNode;
import org.hibernate.hql.internal.ast.tree.BinaryLogicOperatorNode; import org.hibernate.hql.internal.ast.tree.BinaryLogicOperatorNode;
import org.hibernate.hql.internal.ast.tree.BooleanLiteralNode; import org.hibernate.hql.internal.ast.tree.BooleanLiteralNode;
import org.hibernate.hql.internal.ast.tree.CastFunctionNode;
import org.hibernate.hql.internal.ast.tree.SearchedCaseNode; import org.hibernate.hql.internal.ast.tree.SearchedCaseNode;
import org.hibernate.hql.internal.ast.tree.SimpleCaseNode; import org.hibernate.hql.internal.ast.tree.SimpleCaseNode;
import org.hibernate.hql.internal.ast.tree.CollectionFunction; import org.hibernate.hql.internal.ast.tree.CollectionFunction;
@ -133,6 +134,8 @@ public class SqlASTFactory extends ASTFactory implements HqlSqlTokenTypes {
return SqlFragment.class; return SqlFragment.class;
case METHOD_CALL: case METHOD_CALL:
return MethodNode.class; return MethodNode.class;
case CAST:
return CastFunctionNode.class;
case ELEMENTS: case ELEMENTS:
case INDICES: case INDICES:
return CollectionFunction.class; return CollectionFunction.class;

View File

@ -208,7 +208,12 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
else { else {
// this function has a registered SQLFunction -> redirect output and catch the arguments // this function has a registered SQLFunction -> redirect output and catch the arguments
outputStack.addFirst( writer ); outputStack.addFirst( writer );
writer = new FunctionArguments(); if ( node.getType() == CAST ) {
writer = new CastFunctionArguments();
}
else {
writer = new StandardFunctionArguments();
}
} }
} }
@ -222,7 +227,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
else { else {
final Type functionType = functionNode.getFirstArgumentType(); final Type functionType = functionNode.getFirstArgumentType();
// this function has a registered SQLFunction -> redirect output and catch the arguments // this function has a registered SQLFunction -> redirect output and catch the arguments
FunctionArguments functionArguments = (FunctionArguments) writer; FunctionArgumentsCollectingWriter functionArguments = (FunctionArgumentsCollectingWriter) writer;
writer = outputStack.removeFirst(); writer = outputStack.removeFirst();
out( sqlFunction.render( functionType, functionArguments.getArgs(), sessionFactory ) ); out( sqlFunction.render( functionType, functionArguments.getArgs(), sessionFactory ) );
} }
@ -246,11 +251,15 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
void commaBetweenParameters(String comma); void commaBetweenParameters(String comma);
} }
interface FunctionArgumentsCollectingWriter extends SqlWriter {
public List getArgs();
}
/** /**
* SQL function processing code redirects generated SQL output to an instance of this class * SQL function processing code redirects generated SQL output to an instance of this class
* which catches function arguments. * which catches function arguments.
*/ */
class FunctionArguments implements SqlWriter { class StandardFunctionArguments implements FunctionArgumentsCollectingWriter {
private int argInd; private int argInd;
private final List<String> args = new ArrayList<String>( 3 ); private final List<String> args = new ArrayList<String>( 3 );
@ -274,6 +283,28 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
} }
} }
/**
* SQL function processing code redirects generated SQL output to an instance of this class
* which catches function arguments.
*/
class CastFunctionArguments implements FunctionArgumentsCollectingWriter {
private final List<String> args = new ArrayList<String>( 3 );
@Override
public void clause(String clause) {
args.add( clause );
}
@Override
public void commaBetweenParameters(String comma) {
// todo : should this be an exception? Its not likely to end well if this method is called here...
}
public List getArgs() {
return args;
}
}
/** /**
* The default SQL writer. * The default SQL writer.
*/ */

View File

@ -0,0 +1,115 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2014, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.hql.internal.ast.tree;
import org.hibernate.QueryException;
import org.hibernate.dialect.function.CastFunction;
import org.hibernate.dialect.function.SQLFunction;
import org.hibernate.hql.internal.ast.util.ColumnHelper;
import org.hibernate.type.Type;
import antlr.SemanticException;
/**
* Represents a cast function call. We handle this specially because its type
* argument has a semantic meaning to the HQL query (its not just pass through).
*
* @author Steve Ebersole
*/
public class CastFunctionNode extends AbstractSelectExpression implements FunctionNode {
private SQLFunction dialectCastFunction;
private Node expressionNode;
private IdentNode typeNode;
private Type castType;
/**
* Called from the hql-sql grammar after the children of the CAST have been resolved.
*
* @param inSelect Is this call part of the SELECT clause?
*/
public void resolve(boolean inSelect) {
this.dialectCastFunction = getSessionFactoryHelper().findSQLFunction( "cast" );
if ( dialectCastFunction == null ) {
dialectCastFunction = CastFunction.INSTANCE;
}
this.expressionNode = (Node) getFirstChild();
if ( expressionNode == null ) {
throw new QueryException( "Could not resolve expression to CAST" );
}
if ( SqlNode.class.isInstance( expressionNode ) ) {
final Type expressionType = ( (SqlNode) expressionNode ).getDataType();
if ( expressionType != null ) {
if ( expressionType.isEntityType() ) {
throw new QueryException( "Expression to CAST cannot be an entity : " + expressionNode.getText() );
}
if ( expressionType.isComponentType() ) {
throw new QueryException( "Expression to CAST cannot be a composite : " + expressionNode.getText() );
}
if ( expressionType.isCollectionType() ) {
throw new QueryException( "Expression to CAST cannot be a collection : " + expressionNode.getText() );
}
}
}
this.typeNode = (IdentNode) expressionNode.getNextSibling();
if ( typeNode == null ) {
throw new QueryException( "Could not resolve requested type for CAST" );
}
final String typeName = typeNode.getText();
this.castType = getSessionFactoryHelper().getFactory().getTypeResolver().heuristicType( typeName );
if ( castType == null ) {
throw new QueryException( "Could not resolve requested type for CAST : " + typeName );
}
if ( castType.isEntityType() ) {
throw new QueryException( "CAST target type cannot be an entity : " + expressionNode.getText() );
}
if ( castType.isComponentType() ) {
throw new QueryException( "CAST target type cannot be a composite : " + expressionNode.getText() );
}
if ( castType.isCollectionType() ) {
throw new QueryException( "CAST target type cannot be a collection : " + expressionNode.getText() );
}
setDataType( castType );
}
@Override
public SQLFunction getSQLFunction() {
return dialectCastFunction;
}
@Override
public Type getFirstArgumentType() {
return castType;
}
@Override
public void setScalarColumnText(int i) throws SemanticException {
ColumnHelper.generateSingleScalarColumn( this, i );
}
}

View File

@ -0,0 +1,116 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2014, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.test.hql;
import javax.persistence.Entity;
import javax.persistence.Id;
import org.hibernate.Session;
import org.hibernate.testing.junit4.BaseCoreFunctionalTestCase;
import org.junit.Test;
/**
* @author Steve Ebersole
*/
public class CastFunctionTest extends BaseCoreFunctionalTestCase {
@Entity( name="MyEntity" )
public static class MyEntity {
@Id
private Integer id;
private String name;
private Number theLostNumber;
}
@Override
protected Class<?>[] getAnnotatedClasses() {
return new Class[] { MyEntity.class };
}
@Test
public void testStringCasting() {
Session s = openSession();
s.beginTransaction();
// using the short name
s.createQuery( "select cast(e.theLostNumber as string) from MyEntity e" ).list();
// using the java class name
s.createQuery( "select cast(e.theLostNumber as java.lang.String) from MyEntity e" ).list();
// using the fqn Hibernate Type name
s.createQuery( "select cast(e.theLostNumber as org.hibernate.type.StringType) from MyEntity e" ).list();
s.getTransaction().commit();
s.close();
}
@Test
public void testIntegerCasting() {
Session s = openSession();
s.beginTransaction();
// using the short name
s.createQuery( "select cast(e.theLostNumber as integer) from MyEntity e" ).list();
// using the java class name (primitive)
s.createQuery( "select cast(e.theLostNumber as int) from MyEntity e" ).list();
// using the java class name
s.createQuery( "select cast(e.theLostNumber as java.lang.Integer) from MyEntity e" ).list();
// using the fqn Hibernate Type name
s.createQuery( "select cast(e.theLostNumber as org.hibernate.type.IntegerType) from MyEntity e" ).list();
s.getTransaction().commit();
s.close();
}
@Test
public void testLongCasting() {
Session s = openSession();
s.beginTransaction();
// using the short name (also the primitive name)
s.createQuery( "select cast(e.theLostNumber as long) from MyEntity e" ).list();
// using the java class name
s.createQuery( "select cast(e.theLostNumber as java.lang.Long) from MyEntity e" ).list();
// using the fqn Hibernate Type name
s.createQuery( "select cast(e.theLostNumber as org.hibernate.type.LongType) from MyEntity e" ).list();
s.getTransaction().commit();
s.close();
}
@Test
public void testFloatCasting() {
Session s = openSession();
s.beginTransaction();
// using the short name (also the primitive name)
s.createQuery( "select cast(e.theLostNumber as float) from MyEntity e" ).list();
// using the java class name
s.createQuery( "select cast(e.theLostNumber as java.lang.Float) from MyEntity e" ).list();
// using the fqn Hibernate Type name
s.createQuery( "select cast(e.theLostNumber as org.hibernate.type.FloatType) from MyEntity e" ).list();
s.getTransaction().commit();
s.close();
}
}