diff --git a/core/src/main/antlr/hql-sql.g b/core/src/main/antlr/hql-sql.g index 3e177b3540..8bb359be50 100644 --- a/core/src/main/antlr/hql-sql.g +++ b/core/src/main/antlr/hql-sql.g @@ -197,6 +197,8 @@ tokens protected void processFunction(AST functionCall,boolean inSelect) throws SemanticException { } + protected void processAggregation(AST node, boolean inSelect) throws SemanticException { } + protected void processConstructor(AST constructor) throws SemanticException { } protected AST generateNamedParameter(AST delimiterNode, AST nameNode) throws SemanticException { diff --git a/core/src/main/antlr/sql-gen.g b/core/src/main/antlr/sql-gen.g index 064810ce3b..79f7a697a0 100644 --- a/core/src/main/antlr/sql-gen.g +++ b/core/src/main/antlr/sql-gen.g @@ -406,10 +406,13 @@ caseExpr ; aggregate - : #(a:AGGREGATE { out(a); out("("); } expr { out(")"); } ) + : #( + a:AGGREGATE { beginFunctionTemplate( a, a ); } + expr + { endFunctionTemplate( a ); } + ) ; - methodCall : #(m:METHOD_CALL i:METHOD_NAME { beginFunctionTemplate(m,i); } ( #(EXPR_LIST (arguments)? ) )? diff --git a/core/src/main/java/org/hibernate/dialect/DB2Dialect.java b/core/src/main/java/org/hibernate/dialect/DB2Dialect.java index 65c2179b9b..578b9bb84a 100644 --- a/core/src/main/java/org/hibernate/dialect/DB2Dialect.java +++ b/core/src/main/java/org/hibernate/dialect/DB2Dialect.java @@ -31,6 +31,7 @@ import java.sql.Types; import org.hibernate.Hibernate; import org.hibernate.cfg.Environment; +import org.hibernate.dialect.function.AvgWithArgumentCastFunction; import org.hibernate.dialect.function.NoArgSQLFunction; import org.hibernate.dialect.function.SQLFunctionTemplate; import org.hibernate.dialect.function.StandardSQLFunction; @@ -64,6 +65,8 @@ public class DB2Dialect extends Dialect { registerColumnType( Types.LONGVARCHAR, "long varchar" ); registerColumnType( Types.LONGVARBINARY, "long varchar for bit data" ); + registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) ); + registerFunction("abs", new StandardSQLFunction("abs") ); registerFunction("absval", new StandardSQLFunction("absval") ); registerFunction("sign", new StandardSQLFunction("sign", Hibernate.INTEGER) ); diff --git a/core/src/main/java/org/hibernate/dialect/Dialect.java b/core/src/main/java/org/hibernate/dialect/Dialect.java index 24e4c96995..35a4da698a 100644 --- a/core/src/main/java/org/hibernate/dialect/Dialect.java +++ b/core/src/main/java/org/hibernate/dialect/Dialect.java @@ -44,6 +44,7 @@ import org.hibernate.MappingException; import org.hibernate.QueryException; import org.hibernate.LockOptions; import org.hibernate.cfg.Environment; +import org.hibernate.dialect.function.AvgFunction; import org.hibernate.dialect.function.CastFunction; import org.hibernate.dialect.function.SQLFunction; import org.hibernate.dialect.function.SQLFunctionTemplate; @@ -128,24 +129,7 @@ public abstract class Dialect { } ); - STANDARD_AGGREGATE_FUNCTIONS.put( - "avg", - new StandardSQLFunction("avg") { - public Type getReturnType(Type columnType, Mapping mapping) throws QueryException { - int[] sqlTypes; - try { - sqlTypes = columnType.sqlTypes( mapping ); - } - catch ( MappingException me ) { - throw new QueryException( me ); - } - if ( sqlTypes.length != 1 ) { - throw new QueryException( "multi-column type in avg()" ); - } - return Hibernate.DOUBLE; - } - } - ); + STANDARD_AGGREGATE_FUNCTIONS.put( "avg", new AvgFunction() ); STANDARD_AGGREGATE_FUNCTIONS.put( "max", new StandardSQLFunction("max") ); STANDARD_AGGREGATE_FUNCTIONS.put( "min", new StandardSQLFunction("min") ); diff --git a/core/src/main/java/org/hibernate/dialect/H2Dialect.java b/core/src/main/java/org/hibernate/dialect/H2Dialect.java index 17bf0d8af3..48cda5eb2a 100644 --- a/core/src/main/java/org/hibernate/dialect/H2Dialect.java +++ b/core/src/main/java/org/hibernate/dialect/H2Dialect.java @@ -28,6 +28,7 @@ import java.sql.Types; import org.hibernate.Hibernate; import org.hibernate.cfg.Environment; +import org.hibernate.dialect.function.AvgWithArgumentCastFunction; import org.hibernate.dialect.function.NoArgSQLFunction; import org.hibernate.dialect.function.StandardSQLFunction; import org.hibernate.dialect.function.VarArgsSQLFunction; @@ -84,6 +85,9 @@ public class H2Dialect extends Dialect { registerColumnType( Types.BLOB, "blob" ); registerColumnType( Types.CLOB, "clob" ); + // Aggregations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) ); + // select topic, syntax from information_schema.help // where section like 'Function%' order by section, topic // diff --git a/core/src/main/java/org/hibernate/dialect/HSQLDialect.java b/core/src/main/java/org/hibernate/dialect/HSQLDialect.java index 7c6774a132..1a0f1a375c 100644 --- a/core/src/main/java/org/hibernate/dialect/HSQLDialect.java +++ b/core/src/main/java/org/hibernate/dialect/HSQLDialect.java @@ -32,6 +32,7 @@ import org.hibernate.Hibernate; import org.hibernate.LockMode; import org.hibernate.StaleObjectStateException; import org.hibernate.JDBCException; +import org.hibernate.dialect.function.AvgWithArgumentCastFunction; import org.hibernate.engine.SessionImplementor; import org.hibernate.persister.entity.Lockable; import org.hibernate.cfg.Environment; @@ -83,6 +84,8 @@ public class HSQLDialect extends Dialect { registerColumnType( Types.LONGVARBINARY, "longvarbinary" ); registerColumnType( Types.LONGVARCHAR, "longvarchar" ); + registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) ); + registerFunction( "ascii", new StandardSQLFunction( "ascii", Hibernate.INTEGER ) ); registerFunction( "char", new StandardSQLFunction( "char", Hibernate.CHARACTER ) ); registerFunction( "lower", new StandardSQLFunction( "lower" ) ); diff --git a/core/src/main/java/org/hibernate/dialect/function/AvgFunction.java b/core/src/main/java/org/hibernate/dialect/function/AvgFunction.java new file mode 100644 index 0000000000..f50f69abb9 --- /dev/null +++ b/core/src/main/java/org/hibernate/dialect/function/AvgFunction.java @@ -0,0 +1,71 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as + * indicated by the @author tags or express copyright attribution + * statements applied by the authors. All third-party contributions are + * distributed under license by Red Hat Inc. + * + * This copyrighted material is made available to anyone wishing to use, modify, + * copy, or redistribute it subject to the terms and conditions of the GNU + * Lesser General Public License, as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License + * for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this distribution; if not, write to: + * Free Software Foundation, Inc. + * 51 Franklin Street, Fifth Floor + * Boston, MA 02110-1301 USA + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.MappingException; +import org.hibernate.QueryException; +import org.hibernate.engine.Mapping; +import org.hibernate.engine.SessionFactoryImplementor; +import org.hibernate.type.DoubleType; +import org.hibernate.type.Type; + +/** + * The basic JPA spec compliant definition poAVG aggregation function. + * + * @author Steve Ebersole + */ +public class AvgFunction implements SQLFunction { + public final Type getReturnType(Type columnType, Mapping mapping) throws QueryException { + int[] sqlTypes; + try { + sqlTypes = columnType.sqlTypes( mapping ); + } + catch ( MappingException me ) { + throw new QueryException( me ); + } + if ( sqlTypes.length != 1 ) { + throw new QueryException( "multiple-column type in avg()" ); + } + return DoubleType.INSTANCE; + } + + public final boolean hasArguments() { + return true; + } + + public final boolean hasParenthesesIfNoArguments() { + return true; + } + + public String render(List args, SessionFactoryImplementor factory) throws QueryException { + return "avg(" + args.get( 0 ) + ")"; + } + + @Override + public final String toString() { + return "avg"; + } +} diff --git a/core/src/main/java/org/hibernate/dialect/function/AvgWithArgumentCastFunction.java b/core/src/main/java/org/hibernate/dialect/function/AvgWithArgumentCastFunction.java new file mode 100644 index 0000000000..9b6ee8bfa2 --- /dev/null +++ b/core/src/main/java/org/hibernate/dialect/function/AvgWithArgumentCastFunction.java @@ -0,0 +1,54 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as + * indicated by the @author tags or express copyright attribution + * statements applied by the authors. All third-party contributions are + * distributed under license by Red Hat Inc. + * + * This copyrighted material is made available to anyone wishing to use, modify, + * copy, or redistribute it subject to the terms and conditions of the GNU + * Lesser General Public License, as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License + * for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this distribution; if not, write to: + * Free Software Foundation, Inc. + * 51 Franklin Street, Fifth Floor + * Boston, MA 02110-1301 USA + */ +package org.hibernate.dialect.function; + +import java.util.List; + +import org.hibernate.QueryException; +import org.hibernate.engine.SessionFactoryImplementor; + +/** + * Some databases strictly return the type of the of the aggregation value for AVG which is + * problematic in the case of averaging integers because the decimals will be dropped. The usual workaround + * is to cast the integer argument as some form of double/decimal. + *

+ * A downside to this approach is that we always wrap the avg() argument in a cast even though we may not need or want + * to. A more full-featured solution would be defining {@link SQLFunction} such that we render based on the first + * argument; essentially have {@link SQLFunction} describe the basic metadata about the function and merge the + * {@link SQLFunction#getReturnType} and {@link SQLFunction#render} methods into a + * + * @author Steve Ebersole + */ +public class AvgWithArgumentCastFunction extends AvgFunction { + private final TemplateRenderer renderer; + + public AvgWithArgumentCastFunction(String castType) { + renderer = new TemplateRenderer( "avg(cast(?1 as " + castType + "))" ); + } + + @Override + public String render(List args, SessionFactoryImplementor factory) throws QueryException { + return renderer.render( args, factory ); + } +} diff --git a/core/src/main/java/org/hibernate/dialect/function/SQLFunction.java b/core/src/main/java/org/hibernate/dialect/function/SQLFunction.java index cc45514425..9a570b8391 100644 --- a/core/src/main/java/org/hibernate/dialect/function/SQLFunction.java +++ b/core/src/main/java/org/hibernate/dialect/function/SQLFunction.java @@ -42,18 +42,6 @@ import org.hibernate.type.Type; * @author David Channon */ public interface SQLFunction { - /** - * The return type of the function. May be either a concrete type which - * is preset, or variable depending upon the type of the first function - * argument. - * - * @param columnType the type of the first argument - * @param mapping The mapping source. - * @return The type to be expected as a return. - * @throws org.hibernate.QueryException Indicates an issue resolving the return type. - */ - public Type getReturnType(Type columnType, Mapping mapping) throws QueryException; - /** * Does this function have any arguments? * @@ -68,14 +56,34 @@ public interface SQLFunction { */ public boolean hasParenthesesIfNoArguments(); + /** + * The return type of the function. May be either a concrete type which + * is preset, or variable depending upon the type of the first function + * argument. + * + * @param columnType the type of the first argument + * @param mapping The mapping source. + * + * @return The type to be expected as a return. + * + * @throws org.hibernate.QueryException Indicates an issue resolving the return type. + * + * @deprecated See http://opensource.atlassian.com/projects/hibernate/browse/HHH-5212 + */ + public Type getReturnType(Type columnType, Mapping mapping) throws QueryException; + /** * Render the function call as SQL fragment. * * @param args The function arguments * @param factory The SessionFactory + * * @return The rendered function call + * * @throws org.hibernate.QueryException Indicates a problem rendering the * function call. + * + * @deprecated See http://opensource.atlassian.com/projects/hibernate/browse/HHH-5212 */ public String render(List args, SessionFactoryImplementor factory) throws QueryException; } diff --git a/core/src/main/java/org/hibernate/dialect/function/SQLFunctionTemplate.java b/core/src/main/java/org/hibernate/dialect/function/SQLFunctionTemplate.java index 6ece9287f3..82b43c5cb5 100755 --- a/core/src/main/java/org/hibernate/dialect/function/SQLFunctionTemplate.java +++ b/core/src/main/java/org/hibernate/dialect/function/SQLFunctionTemplate.java @@ -1,10 +1,10 @@ /* * Hibernate, Relational Persistence for Idiomatic Java * - * Copyright (c) 2008, Red Hat Middleware LLC or third-party contributors as + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as * indicated by the @author tags or express copyright attribution * statements applied by the authors. All third-party contributions are - * distributed under license by Red Hat Middleware LLC. + * distributed under license by Red Hat Inc. * * This copyrighted material is made available to anyone wishing to use, modify, * copy, or redistribute it subject to the terms and conditions of the GNU @@ -20,7 +20,6 @@ * Free Software Foundation, Inc. * 51 Franklin Street, Fifth Floor * Boston, MA 02110-1301 USA - * */ package org.hibernate.dialect.function; @@ -29,7 +28,6 @@ import org.hibernate.engine.Mapping; import org.hibernate.engine.SessionFactoryImplementor; import org.hibernate.type.Type; -import java.util.ArrayList; import java.util.List; /** @@ -45,102 +43,51 @@ import java.util.List; */ public class SQLFunctionTemplate implements SQLFunction { private final Type type; - private final boolean hasArguments; + private final TemplateRenderer renderer; private final boolean hasParenthesesIfNoArgs; - private final String template; - private final String[] chunks; - private final int[] paramIndexes; - public SQLFunctionTemplate(Type type, String template) { this( type, template, true ); } public SQLFunctionTemplate(Type type, String template, boolean hasParenthesesIfNoArgs) { this.type = type; - this.template = template; - - List chunkList = new ArrayList(); - List paramList = new ArrayList(); - StringBuffer chunk = new StringBuffer( 10 ); - StringBuffer index = new StringBuffer( 2 ); - - for ( int i = 0; i < template.length(); ++i ) { - char c = template.charAt( i ); - if ( c == '?' ) { - chunkList.add( chunk.toString() ); - chunk.delete( 0, chunk.length() ); - - while ( ++i < template.length() ) { - c = template.charAt( i ); - if ( Character.isDigit( c ) ) { - index.append( c ); - } - else { - chunk.append( c ); - break; - } - } - - paramList.add( new Integer( Integer.parseInt( index.toString() ) - 1 ) ); - index.delete( 0, index.length() ); - } - else { - chunk.append( c ); - } - } - - if ( chunk.length() > 0 ) { - chunkList.add( chunk.toString() ); - } - - chunks = ( String[] ) chunkList.toArray( new String[chunkList.size()] ); - paramIndexes = new int[paramList.size()]; - for ( int i = 0; i < paramIndexes.length; ++i ) { - paramIndexes[i] = ( ( Integer ) paramList.get( i ) ).intValue(); - } - - hasArguments = paramIndexes.length > 0; + this.renderer = new TemplateRenderer( template ); this.hasParenthesesIfNoArgs = hasParenthesesIfNoArgs; } /** - * Applies the template to passed in arguments. - * @param args function arguments - * - * @return generated SQL function call + * {@inheritDoc} */ public String render(List args, SessionFactoryImplementor factory) { - StringBuffer buf = new StringBuffer(); - for ( int i = 0; i < chunks.length; ++i ) { - if ( i < paramIndexes.length ) { - Object arg = paramIndexes[i] < args.size() ? args.get( paramIndexes[i] ) : null; - if ( arg != null ) { - buf.append( chunks[i] ).append( arg ); - } - } - else { - buf.append( chunks[i] ); - } - } - return buf.toString(); + return renderer.render( args, factory ); } - // SQLFunction implementation - + /** + * {@inheritDoc} + */ public Type getReturnType(Type columnType, Mapping mapping) throws QueryException { return type; } + /** + * {@inheritDoc} + */ public boolean hasArguments() { - return hasArguments; + return renderer.getAnticipatedNumberOfArguments() > 0; } + /** + * {@inheritDoc} + */ public boolean hasParenthesesIfNoArguments() { return hasParenthesesIfNoArgs; } + /** + * {@inheritDoc} + */ public String toString() { - return template; + return renderer.getTemplate(); } } diff --git a/core/src/main/java/org/hibernate/dialect/function/TemplateRenderer.java b/core/src/main/java/org/hibernate/dialect/function/TemplateRenderer.java new file mode 100644 index 0000000000..fd54e772e5 --- /dev/null +++ b/core/src/main/java/org/hibernate/dialect/function/TemplateRenderer.java @@ -0,0 +1,121 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as + * indicated by the @author tags or express copyright attribution + * statements applied by the authors. All third-party contributions are + * distributed under license by Red Hat Inc. + * + * This copyrighted material is made available to anyone wishing to use, modify, + * copy, or redistribute it subject to the terms and conditions of the GNU + * Lesser General Public License, as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License + * for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this distribution; if not, write to: + * Free Software Foundation, Inc. + * 51 Franklin Street, Fifth Floor + * Boston, MA 02110-1301 USA + */ +package org.hibernate.dialect.function; + +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.hibernate.engine.SessionFactoryImplementor; + +/** + * Delegate for handling function "templates". + * + * @author Steve Ebersole + */ +public class TemplateRenderer { + private static final Logger log = LoggerFactory.getLogger( TemplateRenderer.class ); + + private final String template; + private final String[] chunks; + private final int[] paramIndexes; + + @SuppressWarnings({ "UnnecessaryUnboxing" }) + public TemplateRenderer(String template) { + this.template = template; + + List chunkList = new ArrayList(); + List paramList = new ArrayList(); + StringBuffer chunk = new StringBuffer( 10 ); + StringBuffer index = new StringBuffer( 2 ); + + for ( int i = 0; i < template.length(); ++i ) { + char c = template.charAt( i ); + if ( c == '?' ) { + chunkList.add( chunk.toString() ); + chunk.delete( 0, chunk.length() ); + + while ( ++i < template.length() ) { + c = template.charAt( i ); + if ( Character.isDigit( c ) ) { + index.append( c ); + } + else { + chunk.append( c ); + break; + } + } + + paramList.add( Integer.valueOf( index.toString() ) ); + index.delete( 0, index.length() ); + } + else { + chunk.append( c ); + } + } + + if ( chunk.length() > 0 ) { + chunkList.add( chunk.toString() ); + } + + chunks = chunkList.toArray( new String[chunkList.size()] ); + paramIndexes = new int[paramList.size()]; + for ( int i = 0; i < paramIndexes.length; ++i ) { + paramIndexes[i] = paramList.get( i ).intValue(); + } + } + + public String getTemplate() { + return template; + } + + public int getAnticipatedNumberOfArguments() { + return paramIndexes.length; + } + + @SuppressWarnings({ "UnusedDeclaration" }) + public String render(List args, SessionFactoryImplementor factory) { + int numberOfArguments = args.size(); + if ( getAnticipatedNumberOfArguments() > 0 && numberOfArguments != getAnticipatedNumberOfArguments() ) { + log.warn( "Function template anticipated {} arguments, but {} arguments encountered", + getAnticipatedNumberOfArguments(), numberOfArguments ); + } + StringBuffer buf = new StringBuffer(); + for ( int i = 0; i < chunks.length; ++i ) { + if ( i < paramIndexes.length ) { + final int index = paramIndexes[i] - 1; + final Object arg = index < numberOfArguments ? args.get( index ) : null; + if ( arg != null ) { + buf.append( chunks[i] ).append( arg ); + } + } + else { + buf.append( chunks[i] ); + } + } + return buf.toString(); + } +} diff --git a/core/src/main/java/org/hibernate/hql/ast/HqlSqlWalker.java b/core/src/main/java/org/hibernate/hql/ast/HqlSqlWalker.java index 70e6529745..9b5489214e 100644 --- a/core/src/main/java/org/hibernate/hql/ast/HqlSqlWalker.java +++ b/core/src/main/java/org/hibernate/hql/ast/HqlSqlWalker.java @@ -47,6 +47,7 @@ import org.hibernate.hql.antlr.HqlSqlBaseWalker; import org.hibernate.hql.antlr.HqlSqlTokenTypes; import org.hibernate.hql.antlr.HqlTokenTypes; import org.hibernate.hql.antlr.SqlTokenTypes; +import org.hibernate.hql.ast.tree.AggregateNode; import org.hibernate.hql.ast.tree.AssignmentSpecification; import org.hibernate.hql.ast.tree.CollectionFunction; import org.hibernate.hql.ast.tree.ConstructorNode; @@ -977,6 +978,11 @@ public class HqlSqlWalker extends HqlSqlBaseWalker implements ErrorReporter, Par methodNode.resolve( inSelect ); } + protected void processAggregation(AST node, boolean inSelect) throws SemanticException { + AggregateNode aggregateNode = ( AggregateNode ) node; + aggregateNode.resolve(); + } + protected void processConstructor(AST constructor) throws SemanticException { ConstructorNode constructorNode = ( ConstructorNode ) constructor; constructorNode.prepare(); diff --git a/core/src/main/java/org/hibernate/hql/ast/SqlGenerator.java b/core/src/main/java/org/hibernate/hql/ast/SqlGenerator.java index f6433e40d8..99cbbb6b84 100644 --- a/core/src/main/java/org/hibernate/hql/ast/SqlGenerator.java +++ b/core/src/main/java/org/hibernate/hql/ast/SqlGenerator.java @@ -32,6 +32,7 @@ import java.util.Arrays; import antlr.RecognitionException; import antlr.collections.AST; import org.hibernate.QueryException; +import org.hibernate.hql.ast.tree.FunctionNode; import org.hibernate.util.StringHelper; import org.hibernate.param.ParameterSpecification; import org.hibernate.dialect.function.SQLFunction; @@ -71,7 +72,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter { private ParseErrorHandler parseErrorHandler; private SessionFactoryImplementor sessionFactory; - private LinkedList outputStack = new LinkedList(); + private LinkedList outputStack = new LinkedList(); private final ASTPrinter printer = new ASTPrinter( SqlTokenTypes.class ); private List collectedParameters = new ArrayList(); @@ -178,31 +179,33 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter { } } - protected void beginFunctionTemplate(AST m, AST i) { - MethodNode methodNode = ( MethodNode ) m; - SQLFunction template = methodNode.getSQLFunction(); - if ( template == null ) { - // if template is null we just write the function out as it appears in the hql statement - super.beginFunctionTemplate( m, i ); + protected void beginFunctionTemplate(AST node, AST nameNode) { + // NOTE for AGGREGATE both nodes are the same; for METHOD the first is the METHOD, the second is the + // METHOD_NAME + FunctionNode functionNode = ( FunctionNode ) node; + SQLFunction sqlFunction = functionNode.getSQLFunction(); + if ( sqlFunction == null ) { + // if SQLFunction is null we just write the function out as it appears in the hql statement + super.beginFunctionTemplate( node, nameNode ); } else { - // this function has a template -> redirect output and catch the arguments + // this function has a registered SQLFunction -> redirect output and catch the arguments outputStack.addFirst( writer ); writer = new FunctionArguments(); } } - protected void endFunctionTemplate(AST m) { - MethodNode methodNode = ( MethodNode ) m; - SQLFunction template = methodNode.getSQLFunction(); - if ( template == null ) { - super.endFunctionTemplate( m ); + protected void endFunctionTemplate(AST node) { + FunctionNode functionNode = ( FunctionNode ) node; + SQLFunction sqlFunction = functionNode.getSQLFunction(); + if ( sqlFunction == null ) { + super.endFunctionTemplate( node ); } else { - // this function has a template -> restore output, apply the template and write the result out - FunctionArguments functionArguments = ( FunctionArguments ) writer; // TODO: Downcast to avoid using an interface? Yuck. - writer = ( SqlWriter ) outputStack.removeFirst(); - out( template.render( functionArguments.getArgs(), sessionFactory ) ); + // this function has a registered SQLFunction -> redirect output and catch the arguments + FunctionArguments functionArguments = ( FunctionArguments ) writer; + writer = outputStack.removeFirst(); + out( sqlFunction.render( functionArguments.getArgs(), sessionFactory ) ); } } @@ -230,7 +233,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter { */ class FunctionArguments implements SqlWriter { private int argInd; - private final List args = new ArrayList( 3 ); + private final List args = new ArrayList(3); public void clause(String clause) { if ( argInd == args.size() ) { diff --git a/core/src/main/java/org/hibernate/hql/ast/tree/AggregateNode.java b/core/src/main/java/org/hibernate/hql/ast/tree/AggregateNode.java index b2fe16589e..70577f11b6 100644 --- a/core/src/main/java/org/hibernate/hql/ast/tree/AggregateNode.java +++ b/core/src/main/java/org/hibernate/hql/ast/tree/AggregateNode.java @@ -1,10 +1,10 @@ /* * Hibernate, Relational Persistence for Idiomatic Java * - * Copyright (c) 2008, Red Hat Middleware LLC or third-party contributors as + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as * indicated by the @author tags or express copyright attribution * statements applied by the authors. All third-party contributions are - * distributed under license by Red Hat Middleware LLC. + * distributed under license by Red Hat Inc. * * This copyrighted material is made available to anyone wishing to use, modify, * copy, or redistribute it subject to the terms and conditions of the GNU @@ -20,31 +20,59 @@ * Free Software Foundation, Inc. * 51 Franklin Street, Fifth Floor * Boston, MA 02110-1301 USA - * */ package org.hibernate.hql.ast.tree; +import org.hibernate.dialect.function.SQLFunction; +import org.hibernate.dialect.function.StandardSQLFunction; import org.hibernate.hql.ast.util.ColumnHelper; import org.hibernate.type.Type; import antlr.SemanticException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Represents an aggregate function i.e. min, max, sum, avg. * * @author Joshua Davis */ -public class AggregateNode extends AbstractSelectExpression implements SelectExpression { +public class AggregateNode extends AbstractSelectExpression implements SelectExpression, FunctionNode { + private static final Logger log = LoggerFactory.getLogger( AggregateNode.class ); - public AggregateNode() { + private SQLFunction sqlFunction; + + public SQLFunction getSQLFunction() { + return sqlFunction; + } + + public void resolve() { + resolveFunction(); + } + + private SQLFunction resolveFunction() { + if ( sqlFunction == null ) { + final String name = getText(); + sqlFunction = getSessionFactoryHelper().findSQLFunction( getText() ); + if ( sqlFunction == null ) { + log.info( "Could not resolve aggregate function {}; using standard definition", name ); + sqlFunction = new StandardSQLFunction( name ); + } + } + return sqlFunction; } public Type getDataType() { // Get the function return value type, based on the type of the first argument. - return getSessionFactoryHelper().findFunctionReturnType( getText(), getFirstChild() ); + return getSessionFactoryHelper().findFunctionReturnType( getText(), resolveFunction(), getFirstChild() ); } public void setScalarColumnText(int i) throws SemanticException { ColumnHelper.generateSingleScalarColumn( this, i ); } + + public boolean isScalar() throws SemanticException { + // functions in a SELECT should always be considered scalar. + return true; + } } diff --git a/core/src/main/java/org/hibernate/hql/ast/tree/FunctionNode.java b/core/src/main/java/org/hibernate/hql/ast/tree/FunctionNode.java new file mode 100644 index 0000000000..1f856a83df --- /dev/null +++ b/core/src/main/java/org/hibernate/hql/ast/tree/FunctionNode.java @@ -0,0 +1,35 @@ +/* + * Hibernate, Relational Persistence for Idiomatic Java + * + * Copyright (c) 2010, Red Hat Inc. or third-party contributors as + * indicated by the @author tags or express copyright attribution + * statements applied by the authors. All third-party contributions are + * distributed under license by Red Hat Inc. + * + * This copyrighted material is made available to anyone wishing to use, modify, + * copy, or redistribute it subject to the terms and conditions of the GNU + * Lesser General Public License, as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License + * for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this distribution; if not, write to: + * Free Software Foundation, Inc. + * 51 Franklin Street, Fifth Floor + * Boston, MA 02110-1301 USA + */ +package org.hibernate.hql.ast.tree; + +import org.hibernate.dialect.function.SQLFunction; + +/** + * Identifies a node which models a SQL function. + * + * @author Steve Ebersole + */ +public interface FunctionNode { + public SQLFunction getSQLFunction(); +} diff --git a/core/src/main/java/org/hibernate/hql/ast/tree/MethodNode.java b/core/src/main/java/org/hibernate/hql/ast/tree/MethodNode.java index c5d5ad7f40..c6abeb3e35 100644 --- a/core/src/main/java/org/hibernate/hql/ast/tree/MethodNode.java +++ b/core/src/main/java/org/hibernate/hql/ast/tree/MethodNode.java @@ -47,7 +47,7 @@ import org.slf4j.LoggerFactory; * * @author josh */ -public class MethodNode extends AbstractSelectExpression implements SelectExpression { +public class MethodNode extends AbstractSelectExpression implements SelectExpression, FunctionNode { private static final Logger log = LoggerFactory.getLogger( MethodNode.class ); diff --git a/core/src/main/java/org/hibernate/hql/ast/util/SessionFactoryHelper.java b/core/src/main/java/org/hibernate/hql/ast/util/SessionFactoryHelper.java index f1daae2f3e..0455e1eb82 100644 --- a/core/src/main/java/org/hibernate/hql/ast/util/SessionFactoryHelper.java +++ b/core/src/main/java/org/hibernate/hql/ast/util/SessionFactoryHelper.java @@ -388,17 +388,19 @@ public class SessionFactoryHelper { * @return the function return type given the function name and the first argument expression node. */ public Type findFunctionReturnType(String functionName, AST first) { - // locate the registered function by the given name SQLFunction sqlFunction = requireSQLFunction( functionName ); + return findFunctionReturnType( functionName, sqlFunction, first ); + } + public Type findFunctionReturnType(String functionName, SQLFunction sqlFunction, AST firstArgument) { // determine the type of the first argument... Type argumentType = null; - if ( first != null ) { + if ( firstArgument != null ) { if ( "cast".equals(functionName) ) { - argumentType = sfi.getTypeResolver().heuristicType( first.getNextSibling().getText() ); + argumentType = sfi.getTypeResolver().heuristicType( firstArgument.getNextSibling().getText() ); } - else if ( first instanceof SqlNode ) { - argumentType = ( (SqlNode) first ).getDataType(); + else if ( SqlNode.class.isInstance( firstArgument ) ) { + argumentType = ( (SqlNode) firstArgument ).getDataType(); } } diff --git a/testsuite/src/test/java/org/hibernate/test/hql/ASTParserLoadingTest.java b/testsuite/src/test/java/org/hibernate/test/hql/ASTParserLoadingTest.java index 45c2f229ff..37678811ac 100644 --- a/testsuite/src/test/java/org/hibernate/test/hql/ASTParserLoadingTest.java +++ b/testsuite/src/test/java/org/hibernate/test/hql/ASTParserLoadingTest.java @@ -1567,7 +1567,7 @@ public class ASTParserLoadingTest extends FunctionalTestCase { public void testAggregation() { Session s = openSession(); - Transaction t = s.beginTransaction(); + s.beginTransaction(); Human h = new Human(); h.setBodyWeight( (float) 74.0 ); h.setHeightInches(120.5); @@ -1580,8 +1580,38 @@ public class ASTParserLoadingTest extends FunctionalTestCase { assertEquals(sum.floatValue(), 74.0, 0.01); assertEquals(avg.doubleValue(), 120.5, 0.01); Long id = (Long) s.createQuery("select max(a.id) from Animal a").uniqueResult(); + assertNotNull( id ); + s.delete( h ); + s.getTransaction().commit(); + s.close(); + + s = openSession(); + s.beginTransaction(); + h = new Human(); + h.setFloatValue( 2.5F ); + h.setIntValue( 1 ); + s.persist( h ); + Human h2 = new Human(); + h2.setFloatValue( 2.5F ); + h2.setIntValue( 2 ); + s.persist( h2 ); + Object[] results = (Object[]) s.createQuery( "select sum(h.floatValue), avg(h.floatValue), sum(h.intValue), avg(h.intValue) from Human h" ) + .uniqueResult(); + // spec says sum() on a float or double value should result in double + assertTrue( Double.class.isInstance( results[0] ) ); + assertEquals( 5D, results[0] ); + // avg() should return a double + assertTrue( Double.class.isInstance( results[1] ) ); + assertEquals( 2.5D, results[1] ); + // spec says sum() on short, int or long should result in long + assertTrue( Long.class.isInstance( results[2] ) ); + assertEquals( 3L, results[2] ); + // avg() should return a double + assertTrue( Double.class.isInstance( results[3] ) ); + assertEquals( 1.5D, results[3] ); s.delete(h); - t.commit(); + s.delete(h2); + s.getTransaction().commit(); s.close(); }