HHH-11042 Implement tuple distinct count emulation

This commit is contained in:
Christian Beikov 2021-09-23 15:24:33 +02:00
parent 3ecc602852
commit 340c1b3f61
11 changed files with 434 additions and 59 deletions

View File

@ -8,7 +8,6 @@ package org.hibernate.dialect;
import org.hibernate.LockMode;
import org.hibernate.LockOptions;
import org.hibernate.dialect.function.CastStrEmulation;
import org.hibernate.dialect.function.TransactSQLStrFunction;
import org.hibernate.query.NullOrdering;
import org.hibernate.boot.TempTableDdlTransactionHandling;

View File

@ -52,6 +52,7 @@ import org.hibernate.sql.exec.spi.JdbcOperation;
import org.hibernate.tool.schema.extract.internal.SequenceInformationExtractorDerbyDatabaseImpl;
import org.hibernate.tool.schema.extract.internal.SequenceInformationExtractorNoOpImpl;
import org.hibernate.tool.schema.extract.spi.SequenceInformationExtractor;
import org.hibernate.type.StandardBasicTypes;
import org.hibernate.type.descriptor.jdbc.DecimalTypeDescriptor;
import org.hibernate.type.descriptor.jdbc.JdbcTypeDescriptor;
import org.hibernate.type.descriptor.jdbc.SmallIntTypeDescriptor;
@ -183,7 +184,13 @@ public class DerbyDialect extends Dialect {
super.initializeFunctionRegistry( queryEngine );
// Derby needs an actual argument type for aggregates like SUM, AVG, MIN, MAX to determine the result type
CommonFunctionFactory.aggregates( this, queryEngine, SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
CommonFunctionFactory.aggregates(
this,
queryEngine,
SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER,
"||",
getCastTypeName( StandardBasicTypes.STRING, null, null, null )
);
CommonFunctionFactory.concat_pipeOperator( queryEngine );
CommonFunctionFactory.cot( queryEngine );
@ -211,7 +218,7 @@ public class DerbyDialect extends Dialect {
.setExactArgumentCount( 2 )
.register();
queryEngine.getSqmFunctionRegistry().register( "concat", new DerbyConcatFunction() );
queryEngine.getSqmFunctionRegistry().register( "concat", new DerbyConcatFunction( this ) );
//no way I can see to pad with anything other than spaces
queryEngine.getSqmFunctionRegistry().register( "lpad", new DerbyLpadEmulation() );

View File

@ -3339,6 +3339,15 @@ public abstract class Dialect implements ConversionContext {
return false;
}
/**
* If {@link #supportsTupleCounts()} is true, does the Dialect require the tuple to be wrapped with parens?
*
* @return boolean
*/
public boolean requiresParensForTupleCounts() {
return supportsTupleCounts();
}
/**
* Does this dialect support `count(distinct a,b)`?
*

View File

@ -494,11 +494,6 @@ public class MySQLDialect extends Dialect {
return getMySQLVersion() >= 570;
}
@Override
public boolean supportsTupleCounts() {
return true;
}
@Override
public boolean supportsUnionAll() {
return getMySQLVersion() >= 500;

View File

@ -33,6 +33,7 @@ import org.hibernate.query.FetchClauseType;
import org.hibernate.query.NullPrecedence;
import org.hibernate.query.TemporalUnit;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.SqlAstTranslatorFactory;
import org.hibernate.sql.ast.spi.StandardSqlAstTranslatorFactory;
@ -147,6 +148,9 @@ public class SQLServerDialect extends AbstractTransactSQLDialect {
public void initializeFunctionRegistry(QueryEngine queryEngine) {
super.initializeFunctionRegistry(queryEngine);
// For SQL-Server we need to cast certain arguments to varchar(max) to be able to concat them
CommonFunctionFactory.aggregates( this, queryEngine, SqlAstNodeRenderingMode.DEFAULT, "+", "varchar(max)" );
CommonFunctionFactory.truncate_round( queryEngine );
CommonFunctionFactory.everyAny_sumIif( queryEngine );
CommonFunctionFactory.bitLength_pattern( queryEngine, "datalength(?1) * 8" );

View File

@ -9,6 +9,7 @@ package org.hibernate.dialect;
import org.hibernate.NotYetImplementedFor6Exception;
import org.hibernate.boot.model.TypeContributions;
import org.hibernate.dialect.function.CommonFunctionFactory;
import org.hibernate.dialect.function.SybaseConcatFunction;
import org.hibernate.engine.jdbc.dialect.spi.DialectResolutionInfo;
import org.hibernate.engine.jdbc.env.spi.IdentifierCaseStrategy;
import org.hibernate.engine.jdbc.env.spi.IdentifierHelper;
@ -28,6 +29,7 @@ import org.hibernate.query.sqm.sql.SqmTranslatorFactory;
import org.hibernate.query.sqm.sql.StandardSqmTranslatorFactory;
import org.hibernate.query.sqm.tree.select.SqmSelectStatement;
import org.hibernate.service.ServiceRegistry;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.SqlAstTranslatorFactory;
import org.hibernate.sql.ast.spi.SqlAstCreationContext;
@ -208,6 +210,11 @@ public class SybaseDialect extends AbstractTransactSQLDialect {
public void initializeFunctionRegistry(QueryEngine queryEngine) {
super.initializeFunctionRegistry(queryEngine);
// For SQL-Server we need to cast certain arguments to varchar(16384) to be able to concat them
CommonFunctionFactory.aggregates( this, queryEngine, SqlAstNodeRenderingMode.DEFAULT, "+", "varchar(16384)" );
queryEngine.getSqmFunctionRegistry().register( "concat", new SybaseConcatFunction( this ) );
//this doesn't work 100% on earlier versions of Sybase
//which were missing the third parameter in charindex()
//TODO: we could emulate it with substring() like in Postgres

View File

@ -17,9 +17,7 @@ import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.model.domain.AllowableFunctionReturnType;
import org.hibernate.metamodel.mapping.BasicValuedMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.model.domain.AllowableFunctionReturnType;
import org.hibernate.query.spi.QueryEngine;
import org.hibernate.query.sqm.produce.function.FunctionReturnTypeResolver;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
@ -1488,6 +1486,15 @@ public class CommonFunctionFactory {
Dialect dialect,
QueryEngine queryEngine,
SqlAstNodeRenderingMode inferenceArgumentRenderingMode) {
aggregates( dialect, queryEngine, inferenceArgumentRenderingMode, "||", null );
}
public static void aggregates(
Dialect dialect,
QueryEngine queryEngine,
SqlAstNodeRenderingMode inferenceArgumentRenderingMode,
String concatOperator,
String concatArgumentCastType) {
queryEngine.getSqmFunctionRegistry().namedAggregateDescriptorBuilder( "max" )
.setArgumentRenderingMode( inferenceArgumentRenderingMode )
.setExactArgumentCount( 1 )
@ -1607,7 +1614,10 @@ public class CommonFunctionFactory {
.setExactArgumentCount( 1 )
.register();
queryEngine.getSqmFunctionRegistry().register( CountFunction.FUNCTION_NAME, new CountFunction( dialect ) );
queryEngine.getSqmFunctionRegistry().register(
CountFunction.FUNCTION_NAME,
new CountFunction( dialect, concatOperator, concatArgumentCastType )
);
}
public static void math(QueryEngine queryEngine) {

View File

@ -6,32 +6,33 @@
*/
package org.hibernate.dialect.function;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.EntityIdentifierMapping;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.metamodel.mapping.JdbcMappingContainer;
import org.hibernate.query.CastType;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.function.FunctionKind;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.query.sqm.sql.internal.AbstractSqmPathInterpretation;
import org.hibernate.query.sqm.sql.internal.EntityValuedPathInterpretation;
import org.hibernate.query.sqm.sql.internal.NonAggregatedCompositeValuedPathInterpretation;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.CaseSearchedExpression;
import org.hibernate.sql.ast.tree.expression.Distinct;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.sql.ast.tree.expression.NullnessLiteral;
import org.hibernate.sql.ast.tree.expression.QueryLiteral;
import org.hibernate.sql.ast.tree.expression.SqlTuple;
import org.hibernate.sql.ast.tree.expression.SqlTupleContainer;
import org.hibernate.sql.ast.tree.expression.Star;
import org.hibernate.sql.ast.tree.predicate.Junction;
import org.hibernate.sql.ast.tree.predicate.NullnessPredicate;
import org.hibernate.sql.ast.tree.from.TableGroup;
import org.hibernate.sql.ast.tree.from.TableGroupJoin;
import org.hibernate.sql.ast.tree.predicate.Predicate;
import org.hibernate.sql.ast.tree.select.QuerySpec;
import org.hibernate.type.StandardBasicTypes;
/**
@ -41,8 +42,10 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public static final String FUNCTION_NAME = "count";
private final Dialect dialect;
private final String concatOperator;
private final String concatArgumentCastType;
public CountFunction(Dialect dialect) {
public CountFunction(Dialect dialect, String concatOperator, String concatArgumentCastType) {
super(
FUNCTION_NAME,
FunctionKind.AGGREGATE,
@ -50,6 +53,8 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.LONG )
);
this.dialect = dialect;
this.concatOperator = concatOperator;
this.concatArgumentCastType = concatArgumentCastType;
}
@Override
@ -65,50 +70,176 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
SqlAstTranslator<?> translator) {
final boolean caseWrapper = filter != null && !translator.supportsFilterClause();
final SqlAstNode arg = sqlAstArguments.get( 0 );
final SqlAstNode realArg;
sqlAppender.appendSql( "count(" );
final SqlTuple tuple;
if ( arg instanceof Distinct ) {
sqlAppender.appendSql( "distinct " );
final Expression distinctArg = ( (Distinct) arg ).getExpression();
// todo (6.0): emulate tuple count distinct if necessary
realArg = distinctArg;
}
else {
// If the table group supports inner joins, this means that it is non-optional,
// which means we can omit the tuples and instead use count(*)
final SqlTuple tuple;
if ( ( arg instanceof EntityValuedPathInterpretation<?> || arg instanceof NonAggregatedCompositeValuedPathInterpretation<?> )
&& ( (AbstractSqmPathInterpretation<?>) arg ).getTableGroup().canUseInnerJoins() ) {
realArg = Star.INSTANCE;
}
else if ( !dialect.supportsTupleCounts() && ( tuple = SqlTupleContainer.getSqlTuple( arg ) ) != null ) {
if ( ( tuple = SqlTupleContainer.getSqlTuple( distinctArg ) ) != null ) {
final List<? extends Expression> expressions = tuple.getExpressions();
// Single element tuple
if ( expressions.size() == 1 ) {
realArg = expressions.get( 0 );
renderSimpleArgument( sqlAppender, filter, translator, caseWrapper, expressions.get( 0 ) );
}
// Emulate tuple distinct count
else if ( !dialect.supportsTupleDistinctCounts() ) {
// see https://hibernate.atlassian.net/browse/HHH-11042 details about this implementation
// The idea is to concat all tuple elements, separated by a character that can't appear in the string
// We choose to map this to the NUL character i.e. \0 with the ASCII code 0
// To avoid collisions we must take special care of SQL NULL and empty strings,
// which is why we map them to a special sequence:
// NULL -> \0
// '' -> \0 + argumentNumber
// In the end, the expression looks like the following:
// count(distinct coalesce(nullif(coalesce(col1 || '', '\0'), ''), '\01') || '\0' || coalesce(nullif(coalesce(col2 || '', '\0'), ''), '\02'))
if ( caseWrapper ) {
sqlAppender.appendSql( "case when " );
filter.accept( translator );
sqlAppender.appendSql( " then " );
}
sqlAppender.appendSql( "coalesce(nullif(coalesce(" );
renderCastedArgument( sqlAppender, translator, expressions.get( 0 ) );
int argumentNumber = 1;
for ( int i = 1; i < expressions.size(); i++, argumentNumber++ ) {
// Concat with empty string to get implicit conversion
sqlAppender.appendSql( concatOperator );
sqlAppender.appendSql( "''" );
sqlAppender.appendSql( ",'\\0'),''),'\\0" );
sqlAppender.appendSql( Integer.toString( argumentNumber ) );
sqlAppender.appendSql( "')" );
sqlAppender.appendSql( concatOperator );
sqlAppender.appendSql( "'\\0'" );
sqlAppender.appendSql( concatOperator );
sqlAppender.appendSql( "coalesce(nullif(coalesce(" );
renderCastedArgument( sqlAppender, translator, expressions.get( i ) );
}
// Concat with empty string to get implicit conversion
sqlAppender.appendSql( concatOperator );
sqlAppender.appendSql( "''" );
sqlAppender.appendSql( ",'\\0'),''),'\\0" );
sqlAppender.appendSql( Integer.toString( argumentNumber ) );
sqlAppender.appendSql( "')" );
if ( caseWrapper ) {
sqlAppender.appendSql( " else null end" );
}
}
else {
final List<CaseSearchedExpression.WhenFragment> whenFragments = new ArrayList<>( 1 );
final Junction junction = new Junction( Junction.Nature.DISJUNCTION );
for ( Expression expression : expressions ) {
junction.add( new NullnessPredicate( expression ) );
}
whenFragments.add(
new CaseSearchedExpression.WhenFragment(
junction,
new NullnessLiteral( StandardBasicTypes.INTEGER )
)
);
realArg = new CaseSearchedExpression(
StandardBasicTypes.INTEGER,
whenFragments,
new QueryLiteral<>( 1, StandardBasicTypes.INTEGER )
renderTupleCountSupported(
sqlAppender,
filter,
translator,
caseWrapper,
tuple,
expressions,
dialect.requiresParensForTupleDistinctCounts()
);
}
}
else {
realArg = arg;
renderSimpleArgument( sqlAppender, filter, translator, caseWrapper, distinctArg );
}
}
else {
if ( canReplaceWithStar( arg, translator ) ) {
renderSimpleArgument( sqlAppender, filter, translator, caseWrapper, Star.INSTANCE );
}
else if ( ( tuple = SqlTupleContainer.getSqlTuple( arg ) ) != null ) {
final List<? extends Expression> expressions = tuple.getExpressions();
// Single element tuple
if ( expressions.size() == 1 ) {
renderSimpleArgument( sqlAppender, filter, translator, caseWrapper, expressions.get( 0 ) );
}
// Emulate the tuple count with a case when expression
else if ( !dialect.supportsTupleCounts() ) {
sqlAppender.appendSql( "case when " );
if ( caseWrapper ) {
filter.accept( translator );
sqlAppender.appendSql( " and " );
}
translator.render( expressions.get( 0 ), SqlAstNodeRenderingMode.DEFAULT );
sqlAppender.appendSql( " is not null" );
for ( int i = 1; i < expressions.size(); i++ ) {
sqlAppender.appendSql( " and " );
translator.render( expressions.get( i ), SqlAstNodeRenderingMode.DEFAULT );
sqlAppender.appendSql( " is not null" );
}
sqlAppender.appendSql( " then 1 else null end" );
}
// Tuple counts are supported
else {
renderTupleCountSupported(
sqlAppender,
filter,
translator,
caseWrapper,
tuple,
expressions,
dialect.requiresParensForTupleCounts()
);
}
}
else {
renderSimpleArgument( sqlAppender, filter, translator, caseWrapper, arg );
}
}
sqlAppender.appendSql( ')' );
if ( filter != null && !caseWrapper ) {
sqlAppender.appendSql( " filter (where " );
filter.accept( translator );
sqlAppender.appendSql( ')' );
}
}
private void renderTupleCountSupported(
SqlAppender sqlAppender,
Predicate filter,
SqlAstTranslator<?> translator,
boolean caseWrapper,
SqlTuple tuple,
List<? extends Expression> expressions,
boolean requiresParenthesis) {
if ( caseWrapper ) {
// Add the case wrapper as first element instead of wrapping everything.
// Rendering "Star" will result in `case when FILTER then 1 else null end`
if ( requiresParenthesis ) {
sqlAppender.appendSql( '(' );
renderSimpleArgument( sqlAppender, filter, translator, true, Star.INSTANCE );
sqlAppender.appendSql( ',' );
renderCommaSeparatedList( sqlAppender, translator, expressions );
sqlAppender.appendSql( ')' );
}
else {
renderSimpleArgument( sqlAppender, filter, translator, true, Star.INSTANCE );
sqlAppender.appendSql( ',' );
renderCommaSeparatedList( sqlAppender, translator, expressions );
}
}
// Rendering the tuple will add parenthesis around
else if ( requiresParenthesis ) {
translator.render( tuple, SqlAstNodeRenderingMode.DEFAULT );
}
else {
renderCommaSeparatedList( sqlAppender, translator, expressions );
}
}
private void renderCommaSeparatedList(
SqlAppender sqlAppender,
SqlAstTranslator<?> translator,
List<? extends Expression> expressions) {
translator.render( expressions.get( 0 ), SqlAstNodeRenderingMode.DEFAULT );
for ( int i = 1; i < expressions.size(); i++ ) {
sqlAppender.appendSql( ',' );
translator.render( expressions.get( i ), SqlAstNodeRenderingMode.DEFAULT );
}
}
private void renderSimpleArgument(
SqlAppender sqlAppender,
Predicate filter,
SqlAstTranslator<?> translator,
boolean caseWrapper,
SqlAstNode realArg) {
if ( caseWrapper ) {
sqlAppender.appendSql( "case when " );
filter.accept( translator );
@ -124,12 +255,123 @@ public class CountFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
else {
translator.render( realArg, SqlAstNodeRenderingMode.DEFAULT );
}
sqlAppender.appendSql( ')' );
if ( filter != null && !caseWrapper ) {
sqlAppender.appendSql( " filter (where " );
filter.accept( translator );
sqlAppender.appendSql( ')' );
}
private void renderCastedArgument(SqlAppender sqlAppender, SqlAstTranslator<?> translator, Expression realArg) {
if ( concatArgumentCastType == null ) {
translator.render( realArg, SqlAstNodeRenderingMode.DEFAULT );
}
else {
final JdbcMapping sourceMapping = realArg.getExpressionType().getJdbcMappings().get( 0 );
// No need to cast if we already have a string
if ( sourceMapping.getCastType() == CastType.STRING ) {
translator.render( realArg, SqlAstNodeRenderingMode.DEFAULT );
}
else {
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.STRING );
new PatternRenderer( cast.replace( "?2", concatArgumentCastType ) )
.render( sqlAppender, Collections.singletonList( realArg ), translator );
}
}
}
private boolean canReplaceWithStar(SqlAstNode arg, SqlAstTranslator<?> translator) {
// To determine if we can replace the argument with a star, we must know if the argument is nullable
if ( arg instanceof AbstractSqmPathInterpretation<?> ) {
final AbstractSqmPathInterpretation<?> pathInterpretation = (AbstractSqmPathInterpretation<?>) arg;
final TableGroup tableGroup = pathInterpretation.getTableGroup();
final Expression sqlExpression = pathInterpretation.getSqlExpression();
final JdbcMappingContainer expressionType = sqlExpression.getExpressionType();
// The entity identifier mapping is always considered non-nullable
final boolean isNonNullable = expressionType instanceof EntityIdentifierMapping;
// If canUseInnerJoins is given for a table group, this means that it is non-optional
// But we also have to check if it contains joins that could alter the nullability (RIGHT or FULL)
if ( isNonNullable && tableGroup.canUseInnerJoins() && !hasJoinsAlteringNullability( tableGroup ) ) {
// COUNT can only be used in query specs as query groups can only refer positionally in the order by
final QuerySpec querySpec = (QuerySpec) translator.getCurrentQueryPart();
// On top of this, we also have to ensure that there are no neighbouring joins that alter nullability
for ( TableGroup root : querySpec.getFromClause().getRoots() ) {
final Boolean result = hasNeighbouringJoinsAlteringNullability( root, tableGroup );
if ( result != null ) {
return !result;
}
}
return true;
}
}
return false;
}
private Boolean hasNeighbouringJoinsAlteringNullability(TableGroup tableGroup, TableGroup targetTableGroup) {
if ( tableGroup == targetTableGroup ) {
return Boolean.FALSE;
}
final List<TableGroupJoin> tableGroupJoins = tableGroup.getTableGroupJoins();
int tableGroupIndex = -1;
for ( int i = 0; i < tableGroupJoins.size(); i++ ) {
final TableGroupJoin tableGroupJoin = tableGroupJoins.get( i );
final Boolean result = hasNeighbouringJoinsAlteringNullability(
tableGroupJoin.getJoinedGroup(),
targetTableGroup
);
if ( result == Boolean.TRUE ) {
return Boolean.TRUE;
}
else if ( result != null ) {
tableGroupIndex = i;
break;
}
}
if ( tableGroupIndex != -1 ) {
for ( int i = 0; i < tableGroupJoins.size(); i++ ) {
if ( i == tableGroupIndex ) {
continue;
}
final TableGroupJoin tableGroupJoin = tableGroupJoins.get( i );
if ( hasJoinsAlteringNullability( tableGroupJoin ) ) {
return Boolean.TRUE;
}
}
return Boolean.FALSE;
}
return null;
}
private boolean hasJoinsAlteringNullability(TableGroup tableGroup) {
for ( TableGroupJoin tableGroupJoin : tableGroup.getTableGroupJoins() ) {
switch ( tableGroupJoin.getJoinType() ) {
case INNER:
case LEFT:
case CROSS:
if ( hasJoinsAlteringNullability( tableGroupJoin.getJoinedGroup() ) ) {
return true;
}
break;
default:
// Other joins affect the nullability
return true;
}
}
return false;
}
private boolean hasJoinsAlteringNullability(TableGroupJoin neighbourJoin) {
switch ( neighbourJoin.getJoinType() ) {
case INNER:
case LEFT:
case CROSS:
if ( hasJoinsAlteringNullability( neighbourJoin.getJoinedGroup() ) ) {
return true;
}
break;
default:
// Other joins affect the nullability
return true;
}
return false;
}
@Override

View File

@ -6,34 +6,56 @@
*/
package org.hibernate.dialect.function;
import java.util.Collections;
import java.util.List;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.query.CastType;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.StandardBasicTypes;
public class DerbyConcatFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
public DerbyConcatFunction() {
private final Dialect dialect;
public DerbyConcatFunction(Dialect dialect) {
super(
"concat",
StandardArgumentsValidators.min( 1 ),
StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.STRING )
);
this.dialect = dialect;
}
@Override
public void render(SqlAppender sqlAppender, List<SqlAstNode> sqlAstArguments, SqlAstTranslator<?> walker) {
sqlAppender.appendSql( '(' );
walker.render( sqlAstArguments.get( 0 ), SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
renderAsString( sqlAppender, walker, (Expression) sqlAstArguments.get( 0 ) );
for ( int i = 1; i < sqlAstArguments.size(); i++ ) {
sqlAppender.appendSql( "||" );
walker.render( sqlAstArguments.get( i ), SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
renderAsString( sqlAppender, walker, (Expression) sqlAstArguments.get( i ) );
}
sqlAppender.appendSql( ')' );
}
private void renderAsString(SqlAppender sqlAppender, SqlAstTranslator<?> translator, Expression expression) {
final JdbcMapping sourceMapping = expression.getExpressionType().getJdbcMappings().get( 0 );
// No need to cast if we already have a string
if ( sourceMapping.getCastType() == CastType.STRING ) {
translator.render( expression, SqlAstNodeRenderingMode.NO_PLAIN_PARAMETER );
}
else {
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.STRING );
new PatternRenderer( cast ).render( sqlAppender, Collections.singletonList( expression ), translator );
}
}
}

View File

@ -0,0 +1,61 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later
* See the lgpl.txt file in the root directory or http://www.gnu.org/licenses/lgpl-2.1.html
*/
package org.hibernate.dialect.function;
import java.util.Collections;
import java.util.List;
import org.hibernate.dialect.Dialect;
import org.hibernate.metamodel.mapping.JdbcMapping;
import org.hibernate.query.CastType;
import org.hibernate.query.sqm.function.AbstractSqmSelfRenderingFunctionDescriptor;
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
import org.hibernate.query.sqm.produce.function.internal.PatternRenderer;
import org.hibernate.sql.ast.SqlAstNodeRenderingMode;
import org.hibernate.sql.ast.SqlAstTranslator;
import org.hibernate.sql.ast.spi.SqlAppender;
import org.hibernate.sql.ast.tree.SqlAstNode;
import org.hibernate.sql.ast.tree.expression.Expression;
import org.hibernate.type.StandardBasicTypes;
public class SybaseConcatFunction extends AbstractSqmSelfRenderingFunctionDescriptor {
private final Dialect dialect;
public SybaseConcatFunction(Dialect dialect) {
super(
"concat",
StandardArgumentsValidators.min( 1 ),
StandardFunctionReturnTypeResolvers.invariant( StandardBasicTypes.STRING )
);
this.dialect = dialect;
}
@Override
public void render(SqlAppender sqlAppender, List<SqlAstNode> sqlAstArguments, SqlAstTranslator<?> walker) {
sqlAppender.appendSql( '(' );
renderAsString( sqlAppender, walker, (Expression) sqlAstArguments.get( 0 ) );
for ( int i = 1; i < sqlAstArguments.size(); i++ ) {
sqlAppender.appendSql( '+' );
renderAsString( sqlAppender, walker, (Expression) sqlAstArguments.get( i ) );
}
sqlAppender.appendSql( ')' );
}
private void renderAsString(SqlAppender sqlAppender, SqlAstTranslator<?> translator, Expression expression) {
final JdbcMapping sourceMapping = expression.getExpressionType().getJdbcMappings().get( 0 );
// No need to cast if we already have a string
if ( sourceMapping.getCastType() == CastType.STRING ) {
translator.render( expression, SqlAstNodeRenderingMode.DEFAULT );
}
else {
final String cast = dialect.castPattern( sourceMapping.getCastType(), CastType.STRING );
new PatternRenderer( cast ).render( sqlAppender, Collections.singletonList( expression ), translator );
}
}
}

View File

@ -72,7 +72,6 @@ public class CountExpressionTest extends BaseCoreFunctionalTestCase {
@Test
@TestForIssue(jiraKey = "HHH-9182")
@SkipForDialect(value = DerbyDialect.class, comment = "Derby can't cast from integer to varchar i.e. it requires an intermediary step")
public void testCountDistinctExpression() {
doInHibernate( this::sessionFactory, session -> {
List results = session.createQuery(
@ -91,6 +90,26 @@ public class CountExpressionTest extends BaseCoreFunctionalTestCase {
} );
}
@Test
@TestForIssue(jiraKey = "HHH-11042")
public void testCountDistinctTuple() {
doInHibernate( this::sessionFactory, session -> {
List results = session.createQuery(
"SELECT " +
" d.id, " +
" COUNT(DISTINCT (KEY(l), l)) " +
"FROM Document d " +
"LEFT JOIN d.contacts c " +
"LEFT JOIN c.localized l " +
"GROUP BY d.id")
.getResultList();
assertEquals(1, results.size());
Object[] tuple = (Object[]) results.get( 0 );
assertEquals(1, tuple[0]);
} );
}
@Entity(name = "Document")
public static class Document {