HHH-18570 reallow use of 'date' and 'time' as regular column names in parsed SQL

attempt to simplify the logic here

Signed-off-by: Gavin King <gavin@hibernate.org>
This commit is contained in:
Gavin King 2024-09-05 14:28:28 +02:00
parent cb0268a618
commit 35f96d6e90
2 changed files with 238 additions and 218 deletions

View File

@ -19,6 +19,8 @@ import org.hibernate.query.sqm.function.SqmFunctionDescriptor;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.type.spi.TypeConfiguration;
import static java.lang.Boolean.parseBoolean;
import static java.lang.Character.isLetter;
import static org.hibernate.internal.util.StringHelper.WHITESPACE;
/**
@ -92,6 +94,7 @@ public final class Template {
LITERAL_PREFIXES.add("date");
LITERAL_PREFIXES.add("time");
LITERAL_PREFIXES.add("timestamp");
LITERAL_PREFIXES.add("zone");
}
public static final String TEMPLATE = "$PlaceHolder$";
@ -139,12 +142,9 @@ public final class Template {
// which the tokens occur. Depending on the state of those flags we decide whether we need to qualify
// identifier references.
String symbols = PUNCTUATION +
WHITESPACE +
dialect.openQuote() +
dialect.closeQuote();
StringTokenizer tokens = new StringTokenizer( sqlWhereString, symbols, true );
StringBuilder result = new StringBuilder();
final String symbols = PUNCTUATION + WHITESPACE + dialect.openQuote() + dialect.closeQuote();
final StringTokenizer tokens = new StringTokenizer( sqlWhereString, symbols, true );
final StringBuilder result = new StringBuilder();
boolean quoted = false;
boolean quotedIdentifier = false;
@ -168,7 +168,7 @@ public final class Template {
}
if ( !quoted ) {
boolean isOpenQuote;
final boolean isOpenQuote;
if ( "`".equals(token) ) {
isOpenQuote = !quotedIdentifier;
token = lcToken = isOpenQuote
@ -177,182 +177,27 @@ public final class Template {
quotedIdentifier = isOpenQuote;
isQuoteCharacter = true;
}
else if ( !quotedIdentifier && ( dialect.openQuote()==token.charAt(0) ) ) {
else if ( !quotedIdentifier && dialect.openQuote()==token.charAt(0) ) {
isOpenQuote = true;
quotedIdentifier = true;
isQuoteCharacter = true;
}
else if ( quotedIdentifier && ( dialect.closeQuote()==token.charAt(0) ) ) {
else if ( quotedIdentifier && dialect.closeQuote()==token.charAt(0) ) {
quotedIdentifier = false;
isQuoteCharacter = true;
isOpenQuote = false;
}
else if ( LITERAL_PREFIXES.contains( lcToken ) ) {
if ( "'".equals( nextToken ) ) {
// Don't prefix a literal
result.append( token );
continue;
}
else if ( nextToken != null && Character.isWhitespace( nextToken.charAt( 0 ) ) ) {
final StringBuilder additionalTokens = new StringBuilder();
TimeZoneTokens possibleNextToken = null;
do {
possibleNextToken = possibleNextToken == null
? TimeZoneTokens.getPossibleNextTokens( lcToken )
: possibleNextToken.nextToken();
do {
additionalTokens.append( nextToken );
hasMore = tokens.hasMoreTokens();
nextToken = tokens.nextToken();
} while ( nextToken != null && Character.isWhitespace( nextToken.charAt( 0 ) ) );
} while ( nextToken != null && possibleNextToken.isToken( nextToken ) );
if ( "'".equals( nextToken ) ) {
// Don't prefix a literal
result.append( token );
result.append( additionalTokens );
continue;
}
else {
isOpenQuote = false;
}
}
else {
isOpenQuote = false;
}
}
else {
isOpenQuote = false;
}
if ( isOpenQuote ) {
result.append( placeholder ).append( '.' );
}
}
// Special processing for ANSI SQL EXTRACT function
if ( "extract".equals( lcToken ) && "(".equals( nextToken ) ) {
final String field = extractUntil( tokens, "from" );
final String source = renderWhereStringTemplate(
extractUntil( tokens, ")" ),
placeholder,
dialect,
typeConfiguration,
functionRegistry
);
result.append( "extract(" ).append( field ).append( " from " ).append( source ).append( ')' );
hasMore = tokens.hasMoreTokens();
nextToken = hasMore ? tokens.nextToken() : null;
continue;
}
// Special processing for ANSI SQL TRIM function
if ( "trim".equals( lcToken ) && "(".equals( nextToken ) ) {
List<String> operands = new ArrayList<>();
StringBuilder builder = new StringBuilder();
boolean hasMoreOperands = true;
String operandToken = tokens.nextToken();
switch ( operandToken.toLowerCase( Locale.ROOT ) ) {
case "leading":
case "trailing":
case "both":
operands.add( operandToken );
if ( hasMoreOperands = tokens.hasMoreTokens() ) {
operandToken = tokens.nextToken();
}
break;
}
boolean quotedOperand = false;
int parenthesis = 0;
while ( hasMoreOperands ) {
final boolean isQuote = "'".equals( operandToken );
if ( isQuote ) {
quotedOperand = !quotedOperand;
if ( !quotedOperand ) {
operands.add( builder.append( '\'' ).toString() );
builder.setLength( 0 );
}
else {
builder.append( '\'' );
}
}
else if ( quotedOperand ) {
builder.append( operandToken );
}
else if ( parenthesis != 0 ) {
builder.append( operandToken );
switch ( operandToken ) {
case "(":
parenthesis++;
break;
case ")":
parenthesis--;
break;
}
}
else {
builder.append( operandToken );
switch ( operandToken.toLowerCase( Locale.ROOT ) ) {
case "(":
parenthesis++;
break;
case ")":
parenthesis--;
break;
case "from":
if ( builder.length() != 0 ) {
operands.add( builder.substring( 0, builder.length() - 4 ) );
builder.setLength( 0 );
operands.add( operandToken );
}
break;
}
}
operandToken = tokens.nextToken();
hasMoreOperands = tokens.hasMoreTokens() && ( parenthesis != 0 || ! ")".equals( operandToken ) );
}
if ( builder.length() != 0 ) {
operands.add( builder.toString() );
}
TrimOperands trimOperands = new TrimOperands( operands );
result.append( "trim(" );
if ( trimOperands.trimSpec != null ) {
result.append( trimOperands.trimSpec ).append( ' ' );
}
if ( trimOperands.trimChar != null ) {
if ( trimOperands.trimChar.startsWith( "'" ) && trimOperands.trimChar.endsWith( "'" ) ) {
result.append( trimOperands.trimChar );
}
else {
result.append(
renderWhereStringTemplate( trimOperands.trimSpec, placeholder, dialect, typeConfiguration, functionRegistry )
);
}
result.append( ' ' );
}
if ( trimOperands.from != null ) {
result.append( trimOperands.from ).append( ' ' );
}
else if ( trimOperands.trimSpec != null || trimOperands.trimChar != null ) {
// I think ANSI SQL says that the 'from' is not optional if either trim-spec or trim-char is specified
result.append( "from " );
}
result.append( renderWhereStringTemplate( trimOperands.trimSource, placeholder, dialect, typeConfiguration, functionRegistry ) )
.append( ')' );
hasMore = tokens.hasMoreTokens();
nextToken = hasMore ? tokens.nextToken() : null;
continue;
}
boolean quotedOrWhitespace = quoted || quotedIdentifier || isQuoteCharacter
|| Character.isWhitespace( token.charAt(0) );
final boolean quotedOrWhitespace =
quoted || quotedIdentifier || isQuoteCharacter
|| token.isBlank();
if ( quotedOrWhitespace ) {
result.append( token );
}
@ -370,8 +215,21 @@ public final class Template {
else if ( isNamedParameter(token) ) {
result.append(token);
}
else if ( isExtractFunction( lcToken, nextToken ) ) {
// Special processing for ANSI SQL EXTRACT function
handleExtractFunction( placeholder, dialect, typeConfiguration, functionRegistry, tokens, result );
hasMore = tokens.hasMoreTokens();
nextToken = hasMore ? tokens.nextToken() : null;
}
else if ( isTrimFunction( lcToken, nextToken ) ) {
// Special processing for ANSI SQL TRIM function
handleTrimFunction( placeholder, dialect, typeConfiguration, functionRegistry, tokens, result );
hasMore = tokens.hasMoreTokens();
nextToken = hasMore ? tokens.nextToken() : null;
}
else if ( isIdentifier(token)
&& !isFunctionOrKeyword(lcToken, nextToken, dialect, typeConfiguration, functionRegistry) ) {
&& !isFunctionOrKeyword( lcToken, nextToken, dialect, typeConfiguration, functionRegistry )
&& !isLiteral( lcToken, nextToken, sqlWhereString, symbols, tokens ) ) {
result.append(placeholder)
.append('.')
.append( dialect.quote(token) );
@ -385,7 +243,7 @@ public final class Template {
beforeTable = true;
}
if ( isBoolean( token ) ) {
token = dialect.toBooleanValueString( Boolean.parseBoolean( token ) );
token = dialect.toBooleanValueString( parseBoolean( token ) );
}
result.append(token);
}
@ -401,39 +259,175 @@ public final class Template {
return result.toString();
}
private enum TimeZoneTokens {
NONE,
WITH,
TIME,
ZONE;
private static boolean isTrimFunction(String lcToken, String nextToken) {
return "trim".equals(lcToken) && "(".equals(nextToken);
}
static TimeZoneTokens getPossibleNextTokens(String lctoken) {
switch ( lctoken ) {
case "time":
case "timestamp":
return WITH;
default:
return NONE;
}
}
private static boolean isExtractFunction(String lcToken, String nextToken) {
return "extract".equals(lcToken) && "(".equals(nextToken);
}
public TimeZoneTokens nextToken() {
if ( this == WITH ) {
return TIME;
private static boolean isLiteral(
String lcToken, String next,
String sqlWhereString, String symbols, StringTokenizer tokens) {
if ( LITERAL_PREFIXES.contains( lcToken ) && next != null ) {
// easy cases first
if ( "'".equals(next) ) {
return true;
}
else if ( this == TIME ) {
return ZONE;
else if ( !next.isBlank() ) {
return false;
}
else {
return NONE;
// we need to look ahead in the token stream
// to find the first non-blank token
final StringTokenizer lookahead =
new StringTokenizer( sqlWhereString, symbols, true );
while ( lookahead.countTokens() > tokens.countTokens()+1 ) {
lookahead.nextToken();
}
if ( lookahead.hasMoreTokens() ) {
String nextToken;
do {
nextToken = lookahead.nextToken().toLowerCase(Locale.ROOT);
}
while ( nextToken.isBlank() && lookahead.hasMoreTokens() );
return "'".equals( nextToken )
|| lcToken.equals( "time" ) && "with".equals( nextToken )
|| lcToken.equals( "timestamp" ) && "with".equals( nextToken )
|| lcToken.equals( "time" ) && "zone".equals( nextToken );
}
else {
return false;
}
}
}
public boolean isToken(String token) {
return this != NONE && name().equalsIgnoreCase( token );
else {
return false;
}
}
private static void handleTrimFunction(
String placeholder, Dialect dialect,
TypeConfiguration typeConfiguration,
SqmFunctionRegistry functionRegistry,
StringTokenizer tokens,
StringBuilder result) {
final List<String> operands = new ArrayList<>();
final StringBuilder builder = new StringBuilder();
boolean hasMoreOperands = true;
String operandToken = tokens.nextToken();
switch ( operandToken.toLowerCase( Locale.ROOT ) ) {
case "leading":
case "trailing":
case "both":
operands.add( operandToken );
if ( hasMoreOperands = tokens.hasMoreTokens() ) {
operandToken = tokens.nextToken();
}
break;
}
boolean quotedOperand = false;
int parenthesis = 0;
while ( hasMoreOperands ) {
final boolean isQuote = "'".equals( operandToken );
if ( isQuote ) {
quotedOperand = !quotedOperand;
if ( !quotedOperand ) {
operands.add( builder.append( '\'' ).toString() );
builder.setLength( 0 );
}
else {
builder.append( '\'' );
}
}
else if ( quotedOperand ) {
builder.append( operandToken );
}
else if ( parenthesis != 0 ) {
builder.append( operandToken );
switch ( operandToken ) {
case "(":
parenthesis++;
break;
case ")":
parenthesis--;
break;
}
}
else {
builder.append( operandToken );
switch ( operandToken.toLowerCase( Locale.ROOT ) ) {
case "(":
parenthesis++;
break;
case ")":
parenthesis--;
break;
case "from":
if ( !builder.isEmpty() ) {
operands.add( builder.substring( 0, builder.length() - 4 ) );
builder.setLength( 0 );
operands.add( operandToken );
}
break;
}
}
operandToken = tokens.nextToken();
hasMoreOperands = tokens.hasMoreTokens()
&& ( parenthesis != 0 || ! ")".equals( operandToken ) );
}
if ( !builder.isEmpty() ) {
operands.add( builder.toString() );
}
final TrimOperands trimOperands = new TrimOperands( operands );
result.append( "trim(" );
if ( trimOperands.trimSpec != null ) {
result.append( trimOperands.trimSpec ).append( ' ' );
}
if ( trimOperands.trimChar != null ) {
if ( trimOperands.trimChar.startsWith( "'" ) && trimOperands.trimChar.endsWith( "'" ) ) {
result.append( trimOperands.trimChar );
}
else {
result.append(
renderWhereStringTemplate( trimOperands.trimSpec, placeholder, dialect, typeConfiguration, functionRegistry )
);
}
result.append( ' ' );
}
if ( trimOperands.from != null ) {
result.append( trimOperands.from ).append( ' ' );
}
else if ( trimOperands.trimSpec != null || trimOperands.trimChar != null ) {
// I think ANSI SQL says that the 'from' is not optional if either trim-spec or trim-char is specified
result.append( "from " );
}
result.append( renderWhereStringTemplate( trimOperands.trimSource, placeholder, dialect, typeConfiguration, functionRegistry ) )
.append( ')' );
}
private static void handleExtractFunction(
String placeholder,
Dialect dialect,
TypeConfiguration typeConfiguration,
SqmFunctionRegistry functionRegistry,
StringTokenizer tokens,
StringBuilder result) {
final String field = extractUntil( tokens, "from" );
final String source = renderWhereStringTemplate(
extractUntil( tokens, ")" ),
placeholder,
dialect,
typeConfiguration,
functionRegistry
);
result.append( "extract(" ).append( field ).append( " from " ).append( source ).append( ')' );
}
public static List<String> collectColumnNames(
String sql,
Dialect dialect,
@ -754,7 +748,7 @@ public final class Template {
}
private static String extractUntil(StringTokenizer tokens, String delimiter) {
StringBuilder valueBuilder = new StringBuilder();
final StringBuilder valueBuilder = new StringBuilder();
String token = tokens.nextToken();
while ( ! delimiter.equalsIgnoreCase( token ) ) {
valueBuilder.append( token );
@ -773,12 +767,21 @@ public final class Template {
Dialect dialect,
TypeConfiguration typeConfiguration,
SqmFunctionRegistry functionRegistry) {
return "(".equals( nextToken ) ||
KEYWORDS.contains( lcToken ) ||
isType( lcToken, typeConfiguration ) ||
isFunction( lcToken, nextToken, functionRegistry ) ||
dialect.getKeywords().contains( lcToken ) ||
FUNCTION_KEYWORDS.contains( lcToken );
if ( "(".equals( nextToken ) ) {
return true;
}
else if ( "date".equals( lcToken ) || "time".equals( lcToken ) ) {
// these can be column names on some databases
// TODO: treat 'current date' as a function
return false;
}
else {
return KEYWORDS.contains( lcToken )
|| isType( lcToken, typeConfiguration )
|| isFunction( lcToken, nextToken, functionRegistry )
|| dialect.getKeywords().contains( lcToken )
|| FUNCTION_KEYWORDS.contains( lcToken );
}
}
private static boolean isType(String lcToken, TypeConfiguration typeConfiguration) {
@ -800,10 +803,11 @@ public final class Template {
if ( isBoolean( token ) ) {
return false;
}
return token.charAt( 0 ) == '`' || ( //allow any identifier quoted with backtick
Character.isLetter( token.charAt( 0 ) ) && //only recognizes identifiers beginning with a letter
return token.charAt( 0 ) == '`'
|| ( //allow any identifier quoted with backtick
isLetter( token.charAt( 0 ) ) && //only recognizes identifiers beginning with a letter
token.indexOf( '.' ) < 0
);
);
}
private static boolean isBoolean(String token) {

View File

@ -7,13 +7,9 @@
package org.hibernate.orm.test.sql;
import org.hibernate.dialect.Dialect;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.query.sqm.function.SqmFunctionRegistry;
import org.hibernate.sql.Template;
import org.hibernate.type.spi.TypeConfiguration;
import org.hibernate.testing.orm.domain.StandardDomainModel;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.SessionFactory;
@ -23,23 +19,33 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
@SessionFactory
@DomainModel(standardModels = StandardDomainModel.GAMBIT)
@DomainModel
public class TemplateTest {
@Test
@JiraKey("HHH-18256")
public void templateLiterals(SessionFactoryScope scope) {
assertWhereStringTemplate( "N'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "X'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "BX'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "VARBYTE'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "bytea 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "bytea 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "date 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "time 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "timestamp 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "timestamp with time zone 'a'", scope.getSessionFactory() );
assertWhereStringTemplate( "time with time zone 'a'", scope.getSessionFactory() );
SessionFactoryImplementor factory = scope.getSessionFactory();
assertWhereStringTemplate( "N'a'", factory );
assertWhereStringTemplate( "X'a'", factory );
assertWhereStringTemplate( "BX'a'", factory);
assertWhereStringTemplate( "VARBYTE'a'", factory );
assertWhereStringTemplate( "bytea 'a'", factory );
assertWhereStringTemplate( "bytea 'a'", factory );
assertWhereStringTemplate( "date 'a'", factory );
assertWhereStringTemplate( "time 'a'", factory );
assertWhereStringTemplate( "timestamp 'a'", factory );
assertWhereStringTemplate( "timestamp with time zone 'a'", factory );
assertWhereStringTemplate( "time with time zone 'a'", factory );
assertWhereStringTemplate( "date", "$PlaceHolder$.date", factory );
assertWhereStringTemplate( "time", "$PlaceHolder$.time", factory );
assertWhereStringTemplate( "zone", "$PlaceHolder$.zone", factory );
assertWhereStringTemplate("select date from thetable",
"select $PlaceHolder$.date from thetable", factory );
assertWhereStringTemplate("select date '2000-12-1' from thetable",
"select date '2000-12-1' from thetable", factory );
assertWhereStringTemplate("where date between date '2000-12-1' and date '2002-12-2'",
"where $PlaceHolder$.date between date '2000-12-1' and date '2002-12-2'", factory );
}
private static void assertWhereStringTemplate(String sql, SessionFactoryImplementor sf) {
@ -52,4 +58,14 @@ public class TemplateTest {
assertEquals( sql, template );
}
private static void assertWhereStringTemplate(String sql, String result, SessionFactoryImplementor sf) {
final String template = Template.renderWhereStringTemplate(
sql,
sf.getJdbcServices().getDialect(),
sf.getTypeConfiguration(),
sf.getQueryEngine().getSqmFunctionRegistry()
);
assertEquals( result, template );
}
}