diff --git a/hibernate-core/src/main/java/org/hibernate/sql/Template.java b/hibernate-core/src/main/java/org/hibernate/sql/Template.java index 05c5a78d02..abc62d1170 100644 --- a/hibernate-core/src/main/java/org/hibernate/sql/Template.java +++ b/hibernate-core/src/main/java/org/hibernate/sql/Template.java @@ -19,6 +19,8 @@ import org.hibernate.query.sqm.function.SqmFunctionDescriptor; import org.hibernate.query.sqm.function.SqmFunctionRegistry; import org.hibernate.type.spi.TypeConfiguration; +import static java.lang.Boolean.parseBoolean; +import static java.lang.Character.isLetter; import static org.hibernate.internal.util.StringHelper.WHITESPACE; /** @@ -92,6 +94,7 @@ public final class Template { LITERAL_PREFIXES.add("date"); LITERAL_PREFIXES.add("time"); LITERAL_PREFIXES.add("timestamp"); + LITERAL_PREFIXES.add("zone"); } public static final String TEMPLATE = "$PlaceHolder$"; @@ -139,12 +142,9 @@ public final class Template { // which the tokens occur. Depending on the state of those flags we decide whether we need to qualify // identifier references. - String symbols = PUNCTUATION + - WHITESPACE + - dialect.openQuote() + - dialect.closeQuote(); - StringTokenizer tokens = new StringTokenizer( sqlWhereString, symbols, true ); - StringBuilder result = new StringBuilder(); + final String symbols = PUNCTUATION + WHITESPACE + dialect.openQuote() + dialect.closeQuote(); + final StringTokenizer tokens = new StringTokenizer( sqlWhereString, symbols, true ); + final StringBuilder result = new StringBuilder(); boolean quoted = false; boolean quotedIdentifier = false; @@ -168,7 +168,7 @@ public final class Template { } if ( !quoted ) { - boolean isOpenQuote; + final boolean isOpenQuote; if ( "`".equals(token) ) { isOpenQuote = !quotedIdentifier; token = lcToken = isOpenQuote @@ -177,182 +177,27 @@ public final class Template { quotedIdentifier = isOpenQuote; isQuoteCharacter = true; } - else if ( !quotedIdentifier && ( dialect.openQuote()==token.charAt(0) ) ) { + else if ( !quotedIdentifier && dialect.openQuote()==token.charAt(0) ) { isOpenQuote = true; quotedIdentifier = true; isQuoteCharacter = true; } - else if ( quotedIdentifier && ( dialect.closeQuote()==token.charAt(0) ) ) { + else if ( quotedIdentifier && dialect.closeQuote()==token.charAt(0) ) { quotedIdentifier = false; isQuoteCharacter = true; isOpenQuote = false; } - else if ( LITERAL_PREFIXES.contains( lcToken ) ) { - if ( "'".equals( nextToken ) ) { - // Don't prefix a literal - result.append( token ); - continue; - } - else if ( nextToken != null && Character.isWhitespace( nextToken.charAt( 0 ) ) ) { - final StringBuilder additionalTokens = new StringBuilder(); - TimeZoneTokens possibleNextToken = null; - do { - possibleNextToken = possibleNextToken == null - ? TimeZoneTokens.getPossibleNextTokens( lcToken ) - : possibleNextToken.nextToken(); - do { - additionalTokens.append( nextToken ); - hasMore = tokens.hasMoreTokens(); - nextToken = tokens.nextToken(); - } while ( nextToken != null && Character.isWhitespace( nextToken.charAt( 0 ) ) ); - } while ( nextToken != null && possibleNextToken.isToken( nextToken ) ); - if ( "'".equals( nextToken ) ) { - // Don't prefix a literal - result.append( token ); - result.append( additionalTokens ); - continue; - } - else { - isOpenQuote = false; - } - } - else { - isOpenQuote = false; - } - } else { isOpenQuote = false; } - if ( isOpenQuote ) { result.append( placeholder ).append( '.' ); } } - // Special processing for ANSI SQL EXTRACT function - if ( "extract".equals( lcToken ) && "(".equals( nextToken ) ) { - final String field = extractUntil( tokens, "from" ); - final String source = renderWhereStringTemplate( - extractUntil( tokens, ")" ), - placeholder, - dialect, - typeConfiguration, - functionRegistry - ); - result.append( "extract(" ).append( field ).append( " from " ).append( source ).append( ')' ); - - hasMore = tokens.hasMoreTokens(); - nextToken = hasMore ? tokens.nextToken() : null; - - continue; - } - - // Special processing for ANSI SQL TRIM function - if ( "trim".equals( lcToken ) && "(".equals( nextToken ) ) { - List operands = new ArrayList<>(); - StringBuilder builder = new StringBuilder(); - - boolean hasMoreOperands = true; - String operandToken = tokens.nextToken(); - switch ( operandToken.toLowerCase( Locale.ROOT ) ) { - case "leading": - case "trailing": - case "both": - operands.add( operandToken ); - if ( hasMoreOperands = tokens.hasMoreTokens() ) { - operandToken = tokens.nextToken(); - } - break; - } - boolean quotedOperand = false; - int parenthesis = 0; - while ( hasMoreOperands ) { - final boolean isQuote = "'".equals( operandToken ); - if ( isQuote ) { - quotedOperand = !quotedOperand; - if ( !quotedOperand ) { - operands.add( builder.append( '\'' ).toString() ); - builder.setLength( 0 ); - } - else { - builder.append( '\'' ); - } - } - else if ( quotedOperand ) { - builder.append( operandToken ); - } - else if ( parenthesis != 0 ) { - builder.append( operandToken ); - switch ( operandToken ) { - case "(": - parenthesis++; - break; - case ")": - parenthesis--; - break; - } - } - else { - builder.append( operandToken ); - switch ( operandToken.toLowerCase( Locale.ROOT ) ) { - case "(": - parenthesis++; - break; - case ")": - parenthesis--; - break; - case "from": - if ( builder.length() != 0 ) { - operands.add( builder.substring( 0, builder.length() - 4 ) ); - builder.setLength( 0 ); - operands.add( operandToken ); - } - break; - } - } - operandToken = tokens.nextToken(); - hasMoreOperands = tokens.hasMoreTokens() && ( parenthesis != 0 || ! ")".equals( operandToken ) ); - } - if ( builder.length() != 0 ) { - operands.add( builder.toString() ); - } - - TrimOperands trimOperands = new TrimOperands( operands ); - result.append( "trim(" ); - if ( trimOperands.trimSpec != null ) { - result.append( trimOperands.trimSpec ).append( ' ' ); - } - if ( trimOperands.trimChar != null ) { - if ( trimOperands.trimChar.startsWith( "'" ) && trimOperands.trimChar.endsWith( "'" ) ) { - result.append( trimOperands.trimChar ); - } - else { - result.append( - renderWhereStringTemplate( trimOperands.trimSpec, placeholder, dialect, typeConfiguration, functionRegistry ) - ); - } - result.append( ' ' ); - } - if ( trimOperands.from != null ) { - result.append( trimOperands.from ).append( ' ' ); - } - else if ( trimOperands.trimSpec != null || trimOperands.trimChar != null ) { - // I think ANSI SQL says that the 'from' is not optional if either trim-spec or trim-char is specified - result.append( "from " ); - } - - result.append( renderWhereStringTemplate( trimOperands.trimSource, placeholder, dialect, typeConfiguration, functionRegistry ) ) - .append( ')' ); - - hasMore = tokens.hasMoreTokens(); - nextToken = hasMore ? tokens.nextToken() : null; - - continue; - } - - boolean quotedOrWhitespace = quoted || quotedIdentifier || isQuoteCharacter - || Character.isWhitespace( token.charAt(0) ); - + final boolean quotedOrWhitespace = + quoted || quotedIdentifier || isQuoteCharacter + || token.isBlank(); if ( quotedOrWhitespace ) { result.append( token ); } @@ -370,8 +215,21 @@ public final class Template { else if ( isNamedParameter(token) ) { result.append(token); } + else if ( isExtractFunction( lcToken, nextToken ) ) { + // Special processing for ANSI SQL EXTRACT function + handleExtractFunction( placeholder, dialect, typeConfiguration, functionRegistry, tokens, result ); + hasMore = tokens.hasMoreTokens(); + nextToken = hasMore ? tokens.nextToken() : null; + } + else if ( isTrimFunction( lcToken, nextToken ) ) { + // Special processing for ANSI SQL TRIM function + handleTrimFunction( placeholder, dialect, typeConfiguration, functionRegistry, tokens, result ); + hasMore = tokens.hasMoreTokens(); + nextToken = hasMore ? tokens.nextToken() : null; + } else if ( isIdentifier(token) - && !isFunctionOrKeyword(lcToken, nextToken, dialect, typeConfiguration, functionRegistry) ) { + && !isFunctionOrKeyword( lcToken, nextToken, dialect, typeConfiguration, functionRegistry ) + && !isLiteral( lcToken, nextToken, sqlWhereString, symbols, tokens ) ) { result.append(placeholder) .append('.') .append( dialect.quote(token) ); @@ -385,7 +243,7 @@ public final class Template { beforeTable = true; } if ( isBoolean( token ) ) { - token = dialect.toBooleanValueString( Boolean.parseBoolean( token ) ); + token = dialect.toBooleanValueString( parseBoolean( token ) ); } result.append(token); } @@ -401,39 +259,175 @@ public final class Template { return result.toString(); } - private enum TimeZoneTokens { - NONE, - WITH, - TIME, - ZONE; + private static boolean isTrimFunction(String lcToken, String nextToken) { + return "trim".equals(lcToken) && "(".equals(nextToken); + } - static TimeZoneTokens getPossibleNextTokens(String lctoken) { - switch ( lctoken ) { - case "time": - case "timestamp": - return WITH; - default: - return NONE; - } - } + private static boolean isExtractFunction(String lcToken, String nextToken) { + return "extract".equals(lcToken) && "(".equals(nextToken); + } - public TimeZoneTokens nextToken() { - if ( this == WITH ) { - return TIME; + private static boolean isLiteral( + String lcToken, String next, + String sqlWhereString, String symbols, StringTokenizer tokens) { + if ( LITERAL_PREFIXES.contains( lcToken ) && next != null ) { + // easy cases first + if ( "'".equals(next) ) { + return true; } - else if ( this == TIME ) { - return ZONE; + else if ( !next.isBlank() ) { + return false; } else { - return NONE; + // we need to look ahead in the token stream + // to find the first non-blank token + final StringTokenizer lookahead = + new StringTokenizer( sqlWhereString, symbols, true ); + while ( lookahead.countTokens() > tokens.countTokens()+1 ) { + lookahead.nextToken(); + } + if ( lookahead.hasMoreTokens() ) { + String nextToken; + do { + nextToken = lookahead.nextToken().toLowerCase(Locale.ROOT); + } + while ( nextToken.isBlank() && lookahead.hasMoreTokens() ); + return "'".equals( nextToken ) + || lcToken.equals( "time" ) && "with".equals( nextToken ) + || lcToken.equals( "timestamp" ) && "with".equals( nextToken ) + || lcToken.equals( "time" ) && "zone".equals( nextToken ); + } + else { + return false; + } } } - - public boolean isToken(String token) { - return this != NONE && name().equalsIgnoreCase( token ); + else { + return false; } } + private static void handleTrimFunction( + String placeholder, Dialect dialect, + TypeConfiguration typeConfiguration, + SqmFunctionRegistry functionRegistry, + StringTokenizer tokens, + StringBuilder result) { + final List operands = new ArrayList<>(); + final StringBuilder builder = new StringBuilder(); + + boolean hasMoreOperands = true; + String operandToken = tokens.nextToken(); + switch ( operandToken.toLowerCase( Locale.ROOT ) ) { + case "leading": + case "trailing": + case "both": + operands.add( operandToken ); + if ( hasMoreOperands = tokens.hasMoreTokens() ) { + operandToken = tokens.nextToken(); + } + break; + } + boolean quotedOperand = false; + int parenthesis = 0; + while ( hasMoreOperands ) { + final boolean isQuote = "'".equals( operandToken ); + if ( isQuote ) { + quotedOperand = !quotedOperand; + if ( !quotedOperand ) { + operands.add( builder.append( '\'' ).toString() ); + builder.setLength( 0 ); + } + else { + builder.append( '\'' ); + } + } + else if ( quotedOperand ) { + builder.append( operandToken ); + } + else if ( parenthesis != 0 ) { + builder.append( operandToken ); + switch ( operandToken ) { + case "(": + parenthesis++; + break; + case ")": + parenthesis--; + break; + } + } + else { + builder.append( operandToken ); + switch ( operandToken.toLowerCase( Locale.ROOT ) ) { + case "(": + parenthesis++; + break; + case ")": + parenthesis--; + break; + case "from": + if ( !builder.isEmpty() ) { + operands.add( builder.substring( 0, builder.length() - 4 ) ); + builder.setLength( 0 ); + operands.add( operandToken ); + } + break; + } + } + operandToken = tokens.nextToken(); + hasMoreOperands = tokens.hasMoreTokens() + && ( parenthesis != 0 || ! ")".equals( operandToken ) ); + } + if ( !builder.isEmpty() ) { + operands.add( builder.toString() ); + } + + final TrimOperands trimOperands = new TrimOperands( operands ); + result.append( "trim(" ); + if ( trimOperands.trimSpec != null ) { + result.append( trimOperands.trimSpec ).append( ' ' ); + } + if ( trimOperands.trimChar != null ) { + if ( trimOperands.trimChar.startsWith( "'" ) && trimOperands.trimChar.endsWith( "'" ) ) { + result.append( trimOperands.trimChar ); + } + else { + result.append( + renderWhereStringTemplate( trimOperands.trimSpec, placeholder, dialect, typeConfiguration, functionRegistry ) + ); + } + result.append( ' ' ); + } + if ( trimOperands.from != null ) { + result.append( trimOperands.from ).append( ' ' ); + } + else if ( trimOperands.trimSpec != null || trimOperands.trimChar != null ) { + // I think ANSI SQL says that the 'from' is not optional if either trim-spec or trim-char is specified + result.append( "from " ); + } + + result.append( renderWhereStringTemplate( trimOperands.trimSource, placeholder, dialect, typeConfiguration, functionRegistry ) ) + .append( ')' ); + } + + private static void handleExtractFunction( + String placeholder, + Dialect dialect, + TypeConfiguration typeConfiguration, + SqmFunctionRegistry functionRegistry, + StringTokenizer tokens, + StringBuilder result) { + final String field = extractUntil( tokens, "from" ); + final String source = renderWhereStringTemplate( + extractUntil( tokens, ")" ), + placeholder, + dialect, + typeConfiguration, + functionRegistry + ); + result.append( "extract(" ).append( field ).append( " from " ).append( source ).append( ')' ); + } + public static List collectColumnNames( String sql, Dialect dialect, @@ -754,7 +748,7 @@ public final class Template { } private static String extractUntil(StringTokenizer tokens, String delimiter) { - StringBuilder valueBuilder = new StringBuilder(); + final StringBuilder valueBuilder = new StringBuilder(); String token = tokens.nextToken(); while ( ! delimiter.equalsIgnoreCase( token ) ) { valueBuilder.append( token ); @@ -773,12 +767,21 @@ public final class Template { Dialect dialect, TypeConfiguration typeConfiguration, SqmFunctionRegistry functionRegistry) { - return "(".equals( nextToken ) || - KEYWORDS.contains( lcToken ) || - isType( lcToken, typeConfiguration ) || - isFunction( lcToken, nextToken, functionRegistry ) || - dialect.getKeywords().contains( lcToken ) || - FUNCTION_KEYWORDS.contains( lcToken ); + if ( "(".equals( nextToken ) ) { + return true; + } + else if ( "date".equals( lcToken ) || "time".equals( lcToken ) ) { + // these can be column names on some databases + // TODO: treat 'current date' as a function + return false; + } + else { + return KEYWORDS.contains( lcToken ) + || isType( lcToken, typeConfiguration ) + || isFunction( lcToken, nextToken, functionRegistry ) + || dialect.getKeywords().contains( lcToken ) + || FUNCTION_KEYWORDS.contains( lcToken ); + } } private static boolean isType(String lcToken, TypeConfiguration typeConfiguration) { @@ -800,10 +803,11 @@ public final class Template { if ( isBoolean( token ) ) { return false; } - return token.charAt( 0 ) == '`' || ( //allow any identifier quoted with backtick - Character.isLetter( token.charAt( 0 ) ) && //only recognizes identifiers beginning with a letter + return token.charAt( 0 ) == '`' + || ( //allow any identifier quoted with backtick + isLetter( token.charAt( 0 ) ) && //only recognizes identifiers beginning with a letter token.indexOf( '.' ) < 0 - ); + ); } private static boolean isBoolean(String token) { diff --git a/hibernate-core/src/test/java/org/hibernate/orm/test/sql/TemplateTest.java b/hibernate-core/src/test/java/org/hibernate/orm/test/sql/TemplateTest.java index 87a0cb1a7d..96c3f0253e 100644 --- a/hibernate-core/src/test/java/org/hibernate/orm/test/sql/TemplateTest.java +++ b/hibernate-core/src/test/java/org/hibernate/orm/test/sql/TemplateTest.java @@ -7,13 +7,9 @@ package org.hibernate.orm.test.sql; -import org.hibernate.dialect.Dialect; import org.hibernate.engine.spi.SessionFactoryImplementor; -import org.hibernate.query.sqm.function.SqmFunctionRegistry; import org.hibernate.sql.Template; -import org.hibernate.type.spi.TypeConfiguration; -import org.hibernate.testing.orm.domain.StandardDomainModel; import org.hibernate.testing.orm.junit.DomainModel; import org.hibernate.testing.orm.junit.JiraKey; import org.hibernate.testing.orm.junit.SessionFactory; @@ -23,23 +19,33 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @SessionFactory -@DomainModel(standardModels = StandardDomainModel.GAMBIT) +@DomainModel public class TemplateTest { @Test @JiraKey("HHH-18256") public void templateLiterals(SessionFactoryScope scope) { - assertWhereStringTemplate( "N'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "X'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "BX'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "VARBYTE'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "bytea 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "bytea 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "date 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "time 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "timestamp 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "timestamp with time zone 'a'", scope.getSessionFactory() ); - assertWhereStringTemplate( "time with time zone 'a'", scope.getSessionFactory() ); + SessionFactoryImplementor factory = scope.getSessionFactory(); + assertWhereStringTemplate( "N'a'", factory ); + assertWhereStringTemplate( "X'a'", factory ); + assertWhereStringTemplate( "BX'a'", factory); + assertWhereStringTemplate( "VARBYTE'a'", factory ); + assertWhereStringTemplate( "bytea 'a'", factory ); + assertWhereStringTemplate( "bytea 'a'", factory ); + assertWhereStringTemplate( "date 'a'", factory ); + assertWhereStringTemplate( "time 'a'", factory ); + assertWhereStringTemplate( "timestamp 'a'", factory ); + assertWhereStringTemplate( "timestamp with time zone 'a'", factory ); + assertWhereStringTemplate( "time with time zone 'a'", factory ); + assertWhereStringTemplate( "date", "$PlaceHolder$.date", factory ); + assertWhereStringTemplate( "time", "$PlaceHolder$.time", factory ); + assertWhereStringTemplate( "zone", "$PlaceHolder$.zone", factory ); + assertWhereStringTemplate("select date from thetable", + "select $PlaceHolder$.date from thetable", factory ); + assertWhereStringTemplate("select date '2000-12-1' from thetable", + "select date '2000-12-1' from thetable", factory ); + assertWhereStringTemplate("where date between date '2000-12-1' and date '2002-12-2'", + "where $PlaceHolder$.date between date '2000-12-1' and date '2002-12-2'", factory ); } private static void assertWhereStringTemplate(String sql, SessionFactoryImplementor sf) { @@ -52,4 +58,14 @@ public class TemplateTest { assertEquals( sql, template ); } + private static void assertWhereStringTemplate(String sql, String result, SessionFactoryImplementor sf) { + final String template = Template.renderWhereStringTemplate( + sql, + sf.getJdbcServices().getDialect(), + sf.getTypeConfiguration(), + sf.getQueryEngine().getSqmFunctionRegistry() + ); + assertEquals( result, template ); + } + }