HHH-5173 - hql - average returns double but looses the decimal part

git-svn-id: https://svn.jboss.org/repos/hibernate/core/trunk@19453 1b8cb986-b30d-0410-93ca-fae66ebed9b2
This commit is contained in:
Steve Ebersole 2010-05-10 19:41:53 +00:00
parent 14bdaec5e7
commit 77fba4df70
18 changed files with 441 additions and 137 deletions

View File

@ -197,6 +197,8 @@ tokens
protected void processFunction(AST functionCall,boolean inSelect) throws SemanticException { } protected void processFunction(AST functionCall,boolean inSelect) throws SemanticException { }
protected void processAggregation(AST node, boolean inSelect) throws SemanticException { }
protected void processConstructor(AST constructor) throws SemanticException { } protected void processConstructor(AST constructor) throws SemanticException { }
protected AST generateNamedParameter(AST delimiterNode, AST nameNode) throws SemanticException { protected AST generateNamedParameter(AST delimiterNode, AST nameNode) throws SemanticException {

View File

@ -406,10 +406,13 @@ caseExpr
; ;
aggregate aggregate
: #(a:AGGREGATE { out(a); out("("); } expr { out(")"); } ) : #(
a:AGGREGATE { beginFunctionTemplate( a, a ); }
expr
{ endFunctionTemplate( a ); }
)
; ;
methodCall methodCall
: #(m:METHOD_CALL i:METHOD_NAME { beginFunctionTemplate(m,i); } : #(m:METHOD_CALL i:METHOD_NAME { beginFunctionTemplate(m,i); }
( #(EXPR_LIST (arguments)? ) )? ( #(EXPR_LIST (arguments)? ) )?

View File

@ -31,6 +31,7 @@ import java.sql.Types;
import org.hibernate.Hibernate; import org.hibernate.Hibernate;
import org.hibernate.cfg.Environment; import org.hibernate.cfg.Environment;
import org.hibernate.dialect.function.AvgWithArgumentCastFunction;
import org.hibernate.dialect.function.NoArgSQLFunction; import org.hibernate.dialect.function.NoArgSQLFunction;
import org.hibernate.dialect.function.SQLFunctionTemplate; import org.hibernate.dialect.function.SQLFunctionTemplate;
import org.hibernate.dialect.function.StandardSQLFunction; import org.hibernate.dialect.function.StandardSQLFunction;
@ -64,6 +65,8 @@ public class DB2Dialect extends Dialect {
registerColumnType( Types.LONGVARCHAR, "long varchar" ); registerColumnType( Types.LONGVARCHAR, "long varchar" );
registerColumnType( Types.LONGVARBINARY, "long varchar for bit data" ); registerColumnType( Types.LONGVARBINARY, "long varchar for bit data" );
registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) );
registerFunction("abs", new StandardSQLFunction("abs") ); registerFunction("abs", new StandardSQLFunction("abs") );
registerFunction("absval", new StandardSQLFunction("absval") ); registerFunction("absval", new StandardSQLFunction("absval") );
registerFunction("sign", new StandardSQLFunction("sign", Hibernate.INTEGER) ); registerFunction("sign", new StandardSQLFunction("sign", Hibernate.INTEGER) );

View File

@ -44,6 +44,7 @@ import org.hibernate.MappingException;
import org.hibernate.QueryException; import org.hibernate.QueryException;
import org.hibernate.LockOptions; import org.hibernate.LockOptions;
import org.hibernate.cfg.Environment; import org.hibernate.cfg.Environment;
import org.hibernate.dialect.function.AvgFunction;
import org.hibernate.dialect.function.CastFunction; import org.hibernate.dialect.function.CastFunction;
import org.hibernate.dialect.function.SQLFunction; import org.hibernate.dialect.function.SQLFunction;
import org.hibernate.dialect.function.SQLFunctionTemplate; import org.hibernate.dialect.function.SQLFunctionTemplate;
@ -128,24 +129,7 @@ public abstract class Dialect {
} }
); );
STANDARD_AGGREGATE_FUNCTIONS.put( STANDARD_AGGREGATE_FUNCTIONS.put( "avg", new AvgFunction() );
"avg",
new StandardSQLFunction("avg") {
public Type getReturnType(Type columnType, Mapping mapping) throws QueryException {
int[] sqlTypes;
try {
sqlTypes = columnType.sqlTypes( mapping );
}
catch ( MappingException me ) {
throw new QueryException( me );
}
if ( sqlTypes.length != 1 ) {
throw new QueryException( "multi-column type in avg()" );
}
return Hibernate.DOUBLE;
}
}
);
STANDARD_AGGREGATE_FUNCTIONS.put( "max", new StandardSQLFunction("max") ); STANDARD_AGGREGATE_FUNCTIONS.put( "max", new StandardSQLFunction("max") );
STANDARD_AGGREGATE_FUNCTIONS.put( "min", new StandardSQLFunction("min") ); STANDARD_AGGREGATE_FUNCTIONS.put( "min", new StandardSQLFunction("min") );

View File

@ -28,6 +28,7 @@ import java.sql.Types;
import org.hibernate.Hibernate; import org.hibernate.Hibernate;
import org.hibernate.cfg.Environment; import org.hibernate.cfg.Environment;
import org.hibernate.dialect.function.AvgWithArgumentCastFunction;
import org.hibernate.dialect.function.NoArgSQLFunction; import org.hibernate.dialect.function.NoArgSQLFunction;
import org.hibernate.dialect.function.StandardSQLFunction; import org.hibernate.dialect.function.StandardSQLFunction;
import org.hibernate.dialect.function.VarArgsSQLFunction; import org.hibernate.dialect.function.VarArgsSQLFunction;
@ -84,6 +85,9 @@ public class H2Dialect extends Dialect {
registerColumnType( Types.BLOB, "blob" ); registerColumnType( Types.BLOB, "blob" );
registerColumnType( Types.CLOB, "clob" ); registerColumnType( Types.CLOB, "clob" );
// Aggregations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) );
// select topic, syntax from information_schema.help // select topic, syntax from information_schema.help
// where section like 'Function%' order by section, topic // where section like 'Function%' order by section, topic
// //

View File

@ -32,6 +32,7 @@ import org.hibernate.Hibernate;
import org.hibernate.LockMode; import org.hibernate.LockMode;
import org.hibernate.StaleObjectStateException; import org.hibernate.StaleObjectStateException;
import org.hibernate.JDBCException; import org.hibernate.JDBCException;
import org.hibernate.dialect.function.AvgWithArgumentCastFunction;
import org.hibernate.engine.SessionImplementor; import org.hibernate.engine.SessionImplementor;
import org.hibernate.persister.entity.Lockable; import org.hibernate.persister.entity.Lockable;
import org.hibernate.cfg.Environment; import org.hibernate.cfg.Environment;
@ -83,6 +84,8 @@ public class HSQLDialect extends Dialect {
registerColumnType( Types.LONGVARBINARY, "longvarbinary" ); registerColumnType( Types.LONGVARBINARY, "longvarbinary" );
registerColumnType( Types.LONGVARCHAR, "longvarchar" ); registerColumnType( Types.LONGVARCHAR, "longvarchar" );
registerFunction( "avg", new AvgWithArgumentCastFunction( "double" ) );
registerFunction( "ascii", new StandardSQLFunction( "ascii", Hibernate.INTEGER ) ); registerFunction( "ascii", new StandardSQLFunction( "ascii", Hibernate.INTEGER ) );
registerFunction( "char", new StandardSQLFunction( "char", Hibernate.CHARACTER ) ); registerFunction( "char", new StandardSQLFunction( "char", Hibernate.CHARACTER ) );
registerFunction( "lower", new StandardSQLFunction( "lower" ) ); registerFunction( "lower", new StandardSQLFunction( "lower" ) );

View File

@ -0,0 +1,71 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.dialect.function;
import java.util.List;
import org.hibernate.MappingException;
import org.hibernate.QueryException;
import org.hibernate.engine.Mapping;
import org.hibernate.engine.SessionFactoryImplementor;
import org.hibernate.type.DoubleType;
import org.hibernate.type.Type;
/**
* The basic JPA spec compliant definition po<tt>AVG</tt> aggregation function.
*
* @author Steve Ebersole
*/
public class AvgFunction implements SQLFunction {
public final Type getReturnType(Type columnType, Mapping mapping) throws QueryException {
int[] sqlTypes;
try {
sqlTypes = columnType.sqlTypes( mapping );
}
catch ( MappingException me ) {
throw new QueryException( me );
}
if ( sqlTypes.length != 1 ) {
throw new QueryException( "multiple-column type in avg()" );
}
return DoubleType.INSTANCE;
}
public final boolean hasArguments() {
return true;
}
public final boolean hasParenthesesIfNoArguments() {
return true;
}
public String render(List args, SessionFactoryImplementor factory) throws QueryException {
return "avg(" + args.get( 0 ) + ")";
}
@Override
public final String toString() {
return "avg";
}
}

View File

@ -0,0 +1,54 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.dialect.function;
import java.util.List;
import org.hibernate.QueryException;
import org.hibernate.engine.SessionFactoryImplementor;
/**
* Some databases strictly return the type of the of the aggregation value for <tt>AVG</tt> which is
* problematic in the case of averaging integers because the decimals will be dropped. The usual workaround
* is to cast the integer argument as some form of double/decimal.
* <p/>
* A downside to this approach is that we always wrap the avg() argument in a cast even though we may not need or want
* to. A more full-featured solution would be defining {@link SQLFunction} such that we render based on the first
* argument; essentially have {@link SQLFunction} describe the basic metadata about the function and merge the
* {@link SQLFunction#getReturnType} and {@link SQLFunction#render} methods into a
*
* @author Steve Ebersole
*/
public class AvgWithArgumentCastFunction extends AvgFunction {
private final TemplateRenderer renderer;
public AvgWithArgumentCastFunction(String castType) {
renderer = new TemplateRenderer( "avg(cast(?1 as " + castType + "))" );
}
@Override
public String render(List args, SessionFactoryImplementor factory) throws QueryException {
return renderer.render( args, factory );
}
}

View File

@ -42,18 +42,6 @@ import org.hibernate.type.Type;
* @author David Channon * @author David Channon
*/ */
public interface SQLFunction { public interface SQLFunction {
/**
* The return type of the function. May be either a concrete type which
* is preset, or variable depending upon the type of the first function
* argument.
*
* @param columnType the type of the first argument
* @param mapping The mapping source.
* @return The type to be expected as a return.
* @throws org.hibernate.QueryException Indicates an issue resolving the return type.
*/
public Type getReturnType(Type columnType, Mapping mapping) throws QueryException;
/** /**
* Does this function have any arguments? * Does this function have any arguments?
* *
@ -68,14 +56,34 @@ public interface SQLFunction {
*/ */
public boolean hasParenthesesIfNoArguments(); public boolean hasParenthesesIfNoArguments();
/**
* The return type of the function. May be either a concrete type which
* is preset, or variable depending upon the type of the first function
* argument.
*
* @param columnType the type of the first argument
* @param mapping The mapping source.
*
* @return The type to be expected as a return.
*
* @throws org.hibernate.QueryException Indicates an issue resolving the return type.
*
* @deprecated See http://opensource.atlassian.com/projects/hibernate/browse/HHH-5212
*/
public Type getReturnType(Type columnType, Mapping mapping) throws QueryException;
/** /**
* Render the function call as SQL fragment. * Render the function call as SQL fragment.
* *
* @param args The function arguments * @param args The function arguments
* @param factory The SessionFactory * @param factory The SessionFactory
*
* @return The rendered function call * @return The rendered function call
*
* @throws org.hibernate.QueryException Indicates a problem rendering the * @throws org.hibernate.QueryException Indicates a problem rendering the
* function call. * function call.
*
* @deprecated See http://opensource.atlassian.com/projects/hibernate/browse/HHH-5212
*/ */
public String render(List args, SessionFactoryImplementor factory) throws QueryException; public String render(List args, SessionFactoryImplementor factory) throws QueryException;
} }

View File

@ -1,10 +1,10 @@
/* /*
* Hibernate, Relational Persistence for Idiomatic Java * Hibernate, Relational Persistence for Idiomatic Java
* *
* Copyright (c) 2008, Red Hat Middleware LLC or third-party contributors as * Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution * indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are * statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Middleware LLC. * distributed under license by Red Hat Inc.
* *
* This copyrighted material is made available to anyone wishing to use, modify, * This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU * copy, or redistribute it subject to the terms and conditions of the GNU
@ -20,7 +20,6 @@
* Free Software Foundation, Inc. * Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor * 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA * Boston, MA 02110-1301 USA
*
*/ */
package org.hibernate.dialect.function; package org.hibernate.dialect.function;
@ -29,7 +28,6 @@ import org.hibernate.engine.Mapping;
import org.hibernate.engine.SessionFactoryImplementor; import org.hibernate.engine.SessionFactoryImplementor;
import org.hibernate.type.Type; import org.hibernate.type.Type;
import java.util.ArrayList;
import java.util.List; import java.util.List;
/** /**
@ -45,102 +43,51 @@ import java.util.List;
*/ */
public class SQLFunctionTemplate implements SQLFunction { public class SQLFunctionTemplate implements SQLFunction {
private final Type type; private final Type type;
private final boolean hasArguments; private final TemplateRenderer renderer;
private final boolean hasParenthesesIfNoArgs; private final boolean hasParenthesesIfNoArgs;
private final String template;
private final String[] chunks;
private final int[] paramIndexes;
public SQLFunctionTemplate(Type type, String template) { public SQLFunctionTemplate(Type type, String template) {
this( type, template, true ); this( type, template, true );
} }
public SQLFunctionTemplate(Type type, String template, boolean hasParenthesesIfNoArgs) { public SQLFunctionTemplate(Type type, String template, boolean hasParenthesesIfNoArgs) {
this.type = type; this.type = type;
this.template = template; this.renderer = new TemplateRenderer( template );
List chunkList = new ArrayList();
List paramList = new ArrayList();
StringBuffer chunk = new StringBuffer( 10 );
StringBuffer index = new StringBuffer( 2 );
for ( int i = 0; i < template.length(); ++i ) {
char c = template.charAt( i );
if ( c == '?' ) {
chunkList.add( chunk.toString() );
chunk.delete( 0, chunk.length() );
while ( ++i < template.length() ) {
c = template.charAt( i );
if ( Character.isDigit( c ) ) {
index.append( c );
}
else {
chunk.append( c );
break;
}
}
paramList.add( new Integer( Integer.parseInt( index.toString() ) - 1 ) );
index.delete( 0, index.length() );
}
else {
chunk.append( c );
}
}
if ( chunk.length() > 0 ) {
chunkList.add( chunk.toString() );
}
chunks = ( String[] ) chunkList.toArray( new String[chunkList.size()] );
paramIndexes = new int[paramList.size()];
for ( int i = 0; i < paramIndexes.length; ++i ) {
paramIndexes[i] = ( ( Integer ) paramList.get( i ) ).intValue();
}
hasArguments = paramIndexes.length > 0;
this.hasParenthesesIfNoArgs = hasParenthesesIfNoArgs; this.hasParenthesesIfNoArgs = hasParenthesesIfNoArgs;
} }
/** /**
* Applies the template to passed in arguments. * {@inheritDoc}
* @param args function arguments
*
* @return generated SQL function call
*/ */
public String render(List args, SessionFactoryImplementor factory) { public String render(List args, SessionFactoryImplementor factory) {
StringBuffer buf = new StringBuffer(); return renderer.render( args, factory );
for ( int i = 0; i < chunks.length; ++i ) {
if ( i < paramIndexes.length ) {
Object arg = paramIndexes[i] < args.size() ? args.get( paramIndexes[i] ) : null;
if ( arg != null ) {
buf.append( chunks[i] ).append( arg );
}
}
else {
buf.append( chunks[i] );
}
}
return buf.toString();
} }
// SQLFunction implementation /**
* {@inheritDoc}
*/
public Type getReturnType(Type columnType, Mapping mapping) throws QueryException { public Type getReturnType(Type columnType, Mapping mapping) throws QueryException {
return type; return type;
} }
/**
* {@inheritDoc}
*/
public boolean hasArguments() { public boolean hasArguments() {
return hasArguments; return renderer.getAnticipatedNumberOfArguments() > 0;
} }
/**
* {@inheritDoc}
*/
public boolean hasParenthesesIfNoArguments() { public boolean hasParenthesesIfNoArguments() {
return hasParenthesesIfNoArgs; return hasParenthesesIfNoArgs;
} }
/**
* {@inheritDoc}
*/
public String toString() { public String toString() {
return template; return renderer.getTemplate();
} }
} }

View File

@ -0,0 +1,121 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.dialect.function;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.hibernate.engine.SessionFactoryImplementor;
/**
* Delegate for handling function "templates".
*
* @author Steve Ebersole
*/
public class TemplateRenderer {
private static final Logger log = LoggerFactory.getLogger( TemplateRenderer.class );
private final String template;
private final String[] chunks;
private final int[] paramIndexes;
@SuppressWarnings({ "UnnecessaryUnboxing" })
public TemplateRenderer(String template) {
this.template = template;
List<String> chunkList = new ArrayList<String>();
List<Integer> paramList = new ArrayList<Integer>();
StringBuffer chunk = new StringBuffer( 10 );
StringBuffer index = new StringBuffer( 2 );
for ( int i = 0; i < template.length(); ++i ) {
char c = template.charAt( i );
if ( c == '?' ) {
chunkList.add( chunk.toString() );
chunk.delete( 0, chunk.length() );
while ( ++i < template.length() ) {
c = template.charAt( i );
if ( Character.isDigit( c ) ) {
index.append( c );
}
else {
chunk.append( c );
break;
}
}
paramList.add( Integer.valueOf( index.toString() ) );
index.delete( 0, index.length() );
}
else {
chunk.append( c );
}
}
if ( chunk.length() > 0 ) {
chunkList.add( chunk.toString() );
}
chunks = chunkList.toArray( new String[chunkList.size()] );
paramIndexes = new int[paramList.size()];
for ( int i = 0; i < paramIndexes.length; ++i ) {
paramIndexes[i] = paramList.get( i ).intValue();
}
}
public String getTemplate() {
return template;
}
public int getAnticipatedNumberOfArguments() {
return paramIndexes.length;
}
@SuppressWarnings({ "UnusedDeclaration" })
public String render(List args, SessionFactoryImplementor factory) {
int numberOfArguments = args.size();
if ( getAnticipatedNumberOfArguments() > 0 && numberOfArguments != getAnticipatedNumberOfArguments() ) {
log.warn( "Function template anticipated {} arguments, but {} arguments encountered",
getAnticipatedNumberOfArguments(), numberOfArguments );
}
StringBuffer buf = new StringBuffer();
for ( int i = 0; i < chunks.length; ++i ) {
if ( i < paramIndexes.length ) {
final int index = paramIndexes[i] - 1;
final Object arg = index < numberOfArguments ? args.get( index ) : null;
if ( arg != null ) {
buf.append( chunks[i] ).append( arg );
}
}
else {
buf.append( chunks[i] );
}
}
return buf.toString();
}
}

View File

@ -47,6 +47,7 @@ import org.hibernate.hql.antlr.HqlSqlBaseWalker;
import org.hibernate.hql.antlr.HqlSqlTokenTypes; import org.hibernate.hql.antlr.HqlSqlTokenTypes;
import org.hibernate.hql.antlr.HqlTokenTypes; import org.hibernate.hql.antlr.HqlTokenTypes;
import org.hibernate.hql.antlr.SqlTokenTypes; import org.hibernate.hql.antlr.SqlTokenTypes;
import org.hibernate.hql.ast.tree.AggregateNode;
import org.hibernate.hql.ast.tree.AssignmentSpecification; import org.hibernate.hql.ast.tree.AssignmentSpecification;
import org.hibernate.hql.ast.tree.CollectionFunction; import org.hibernate.hql.ast.tree.CollectionFunction;
import org.hibernate.hql.ast.tree.ConstructorNode; import org.hibernate.hql.ast.tree.ConstructorNode;
@ -977,6 +978,11 @@ public class HqlSqlWalker extends HqlSqlBaseWalker implements ErrorReporter, Par
methodNode.resolve( inSelect ); methodNode.resolve( inSelect );
} }
protected void processAggregation(AST node, boolean inSelect) throws SemanticException {
AggregateNode aggregateNode = ( AggregateNode ) node;
aggregateNode.resolve();
}
protected void processConstructor(AST constructor) throws SemanticException { protected void processConstructor(AST constructor) throws SemanticException {
ConstructorNode constructorNode = ( ConstructorNode ) constructor; ConstructorNode constructorNode = ( ConstructorNode ) constructor;
constructorNode.prepare(); constructorNode.prepare();

View File

@ -32,6 +32,7 @@ import java.util.Arrays;
import antlr.RecognitionException; import antlr.RecognitionException;
import antlr.collections.AST; import antlr.collections.AST;
import org.hibernate.QueryException; import org.hibernate.QueryException;
import org.hibernate.hql.ast.tree.FunctionNode;
import org.hibernate.util.StringHelper; import org.hibernate.util.StringHelper;
import org.hibernate.param.ParameterSpecification; import org.hibernate.param.ParameterSpecification;
import org.hibernate.dialect.function.SQLFunction; import org.hibernate.dialect.function.SQLFunction;
@ -71,7 +72,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
private ParseErrorHandler parseErrorHandler; private ParseErrorHandler parseErrorHandler;
private SessionFactoryImplementor sessionFactory; private SessionFactoryImplementor sessionFactory;
private LinkedList outputStack = new LinkedList(); private LinkedList<SqlWriter> outputStack = new LinkedList<SqlWriter>();
private final ASTPrinter printer = new ASTPrinter( SqlTokenTypes.class ); private final ASTPrinter printer = new ASTPrinter( SqlTokenTypes.class );
private List collectedParameters = new ArrayList(); private List collectedParameters = new ArrayList();
@ -178,31 +179,33 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
} }
} }
protected void beginFunctionTemplate(AST m, AST i) { protected void beginFunctionTemplate(AST node, AST nameNode) {
MethodNode methodNode = ( MethodNode ) m; // NOTE for AGGREGATE both nodes are the same; for METHOD the first is the METHOD, the second is the
SQLFunction template = methodNode.getSQLFunction(); // METHOD_NAME
if ( template == null ) { FunctionNode functionNode = ( FunctionNode ) node;
// if template is null we just write the function out as it appears in the hql statement SQLFunction sqlFunction = functionNode.getSQLFunction();
super.beginFunctionTemplate( m, i ); if ( sqlFunction == null ) {
// if SQLFunction is null we just write the function out as it appears in the hql statement
super.beginFunctionTemplate( node, nameNode );
} }
else { else {
// this function has a template -> redirect output and catch the arguments // this function has a registered SQLFunction -> redirect output and catch the arguments
outputStack.addFirst( writer ); outputStack.addFirst( writer );
writer = new FunctionArguments(); writer = new FunctionArguments();
} }
} }
protected void endFunctionTemplate(AST m) { protected void endFunctionTemplate(AST node) {
MethodNode methodNode = ( MethodNode ) m; FunctionNode functionNode = ( FunctionNode ) node;
SQLFunction template = methodNode.getSQLFunction(); SQLFunction sqlFunction = functionNode.getSQLFunction();
if ( template == null ) { if ( sqlFunction == null ) {
super.endFunctionTemplate( m ); super.endFunctionTemplate( node );
} }
else { else {
// this function has a template -> restore output, apply the template and write the result out // this function has a registered SQLFunction -> redirect output and catch the arguments
FunctionArguments functionArguments = ( FunctionArguments ) writer; // TODO: Downcast to avoid using an interface? Yuck. FunctionArguments functionArguments = ( FunctionArguments ) writer;
writer = ( SqlWriter ) outputStack.removeFirst(); writer = outputStack.removeFirst();
out( template.render( functionArguments.getArgs(), sessionFactory ) ); out( sqlFunction.render( functionArguments.getArgs(), sessionFactory ) );
} }
} }
@ -230,7 +233,7 @@ public class SqlGenerator extends SqlGeneratorBase implements ErrorReporter {
*/ */
class FunctionArguments implements SqlWriter { class FunctionArguments implements SqlWriter {
private int argInd; private int argInd;
private final List args = new ArrayList( 3 ); private final List<String> args = new ArrayList<String>(3);
public void clause(String clause) { public void clause(String clause) {
if ( argInd == args.size() ) { if ( argInd == args.size() ) {

View File

@ -1,10 +1,10 @@
/* /*
* Hibernate, Relational Persistence for Idiomatic Java * Hibernate, Relational Persistence for Idiomatic Java
* *
* Copyright (c) 2008, Red Hat Middleware LLC or third-party contributors as * Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution * indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are * statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Middleware LLC. * distributed under license by Red Hat Inc.
* *
* This copyrighted material is made available to anyone wishing to use, modify, * This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU * copy, or redistribute it subject to the terms and conditions of the GNU
@ -20,31 +20,59 @@
* Free Software Foundation, Inc. * Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor * 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA * Boston, MA 02110-1301 USA
*
*/ */
package org.hibernate.hql.ast.tree; package org.hibernate.hql.ast.tree;
import org.hibernate.dialect.function.SQLFunction;
import org.hibernate.dialect.function.StandardSQLFunction;
import org.hibernate.hql.ast.util.ColumnHelper; import org.hibernate.hql.ast.util.ColumnHelper;
import org.hibernate.type.Type; import org.hibernate.type.Type;
import antlr.SemanticException; import antlr.SemanticException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** /**
* Represents an aggregate function i.e. min, max, sum, avg. * Represents an aggregate function i.e. min, max, sum, avg.
* *
* @author Joshua Davis * @author Joshua Davis
*/ */
public class AggregateNode extends AbstractSelectExpression implements SelectExpression { public class AggregateNode extends AbstractSelectExpression implements SelectExpression, FunctionNode {
private static final Logger log = LoggerFactory.getLogger( AggregateNode.class );
public AggregateNode() { private SQLFunction sqlFunction;
public SQLFunction getSQLFunction() {
return sqlFunction;
}
public void resolve() {
resolveFunction();
}
private SQLFunction resolveFunction() {
if ( sqlFunction == null ) {
final String name = getText();
sqlFunction = getSessionFactoryHelper().findSQLFunction( getText() );
if ( sqlFunction == null ) {
log.info( "Could not resolve aggregate function {}; using standard definition", name );
sqlFunction = new StandardSQLFunction( name );
}
}
return sqlFunction;
} }
public Type getDataType() { public Type getDataType() {
// Get the function return value type, based on the type of the first argument. // Get the function return value type, based on the type of the first argument.
return getSessionFactoryHelper().findFunctionReturnType( getText(), getFirstChild() ); return getSessionFactoryHelper().findFunctionReturnType( getText(), resolveFunction(), getFirstChild() );
} }
public void setScalarColumnText(int i) throws SemanticException { public void setScalarColumnText(int i) throws SemanticException {
ColumnHelper.generateSingleScalarColumn( this, i ); ColumnHelper.generateSingleScalarColumn( this, i );
} }
public boolean isScalar() throws SemanticException {
// functions in a SELECT should always be considered scalar.
return true;
}
} }

View File

@ -0,0 +1,35 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* Copyright (c) 2010, Red Hat Inc. or third-party contributors as
* indicated by the @author tags or express copyright attribution
* statements applied by the authors. All third-party contributions are
* distributed under license by Red Hat Inc.
*
* This copyrighted material is made available to anyone wishing to use, modify,
* copy, or redistribute it subject to the terms and conditions of the GNU
* Lesser General Public License, as published by the Free Software Foundation.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
* for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this distribution; if not, write to:
* Free Software Foundation, Inc.
* 51 Franklin Street, Fifth Floor
* Boston, MA 02110-1301 USA
*/
package org.hibernate.hql.ast.tree;
import org.hibernate.dialect.function.SQLFunction;
/**
* Identifies a node which models a SQL function.
*
* @author Steve Ebersole
*/
public interface FunctionNode {
public SQLFunction getSQLFunction();
}

View File

@ -47,7 +47,7 @@ import org.slf4j.LoggerFactory;
* *
* @author josh * @author josh
*/ */
public class MethodNode extends AbstractSelectExpression implements SelectExpression { public class MethodNode extends AbstractSelectExpression implements SelectExpression, FunctionNode {
private static final Logger log = LoggerFactory.getLogger( MethodNode.class ); private static final Logger log = LoggerFactory.getLogger( MethodNode.class );

View File

@ -388,17 +388,19 @@ public class SessionFactoryHelper {
* @return the function return type given the function name and the first argument expression node. * @return the function return type given the function name and the first argument expression node.
*/ */
public Type findFunctionReturnType(String functionName, AST first) { public Type findFunctionReturnType(String functionName, AST first) {
// locate the registered function by the given name
SQLFunction sqlFunction = requireSQLFunction( functionName ); SQLFunction sqlFunction = requireSQLFunction( functionName );
return findFunctionReturnType( functionName, sqlFunction, first );
}
public Type findFunctionReturnType(String functionName, SQLFunction sqlFunction, AST firstArgument) {
// determine the type of the first argument... // determine the type of the first argument...
Type argumentType = null; Type argumentType = null;
if ( first != null ) { if ( firstArgument != null ) {
if ( "cast".equals(functionName) ) { if ( "cast".equals(functionName) ) {
argumentType = sfi.getTypeResolver().heuristicType( first.getNextSibling().getText() ); argumentType = sfi.getTypeResolver().heuristicType( firstArgument.getNextSibling().getText() );
} }
else if ( first instanceof SqlNode ) { else if ( SqlNode.class.isInstance( firstArgument ) ) {
argumentType = ( (SqlNode) first ).getDataType(); argumentType = ( (SqlNode) firstArgument ).getDataType();
} }
} }

View File

@ -1567,7 +1567,7 @@ public class ASTParserLoadingTest extends FunctionalTestCase {
public void testAggregation() { public void testAggregation() {
Session s = openSession(); Session s = openSession();
Transaction t = s.beginTransaction(); s.beginTransaction();
Human h = new Human(); Human h = new Human();
h.setBodyWeight( (float) 74.0 ); h.setBodyWeight( (float) 74.0 );
h.setHeightInches(120.5); h.setHeightInches(120.5);
@ -1580,8 +1580,38 @@ public class ASTParserLoadingTest extends FunctionalTestCase {
assertEquals(sum.floatValue(), 74.0, 0.01); assertEquals(sum.floatValue(), 74.0, 0.01);
assertEquals(avg.doubleValue(), 120.5, 0.01); assertEquals(avg.doubleValue(), 120.5, 0.01);
Long id = (Long) s.createQuery("select max(a.id) from Animal a").uniqueResult(); Long id = (Long) s.createQuery("select max(a.id) from Animal a").uniqueResult();
assertNotNull( id );
s.delete( h );
s.getTransaction().commit();
s.close();
s = openSession();
s.beginTransaction();
h = new Human();
h.setFloatValue( 2.5F );
h.setIntValue( 1 );
s.persist( h );
Human h2 = new Human();
h2.setFloatValue( 2.5F );
h2.setIntValue( 2 );
s.persist( h2 );
Object[] results = (Object[]) s.createQuery( "select sum(h.floatValue), avg(h.floatValue), sum(h.intValue), avg(h.intValue) from Human h" )
.uniqueResult();
// spec says sum() on a float or double value should result in double
assertTrue( Double.class.isInstance( results[0] ) );
assertEquals( 5D, results[0] );
// avg() should return a double
assertTrue( Double.class.isInstance( results[1] ) );
assertEquals( 2.5D, results[1] );
// spec says sum() on short, int or long should result in long
assertTrue( Long.class.isInstance( results[2] ) );
assertEquals( 3L, results[2] );
// avg() should return a double
assertTrue( Double.class.isInstance( results[3] ) );
assertEquals( 1.5D, results[3] );
s.delete(h); s.delete(h);
t.commit(); s.delete(h2);
s.getTransaction().commit();
s.close(); s.close();
} }