HHH-16867 - support index and join hints in the CockroachDB dialect

This commit is contained in:
Karel Maesen 2023-06-28 23:07:49 +02:00 committed by Christian Beikov
parent b7bdcd100f
commit 8df6d39b97
4 changed files with 423 additions and 0 deletions

View File

@ -14,6 +14,7 @@ import java.time.temporal.ChronoField;
import java.time.temporal.TemporalAccessor;
import java.util.Calendar;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.regex.Matcher;
@ -1079,6 +1080,33 @@ public class CockroachDialect extends Dialect {
};
}
/**
* Applies the hints to the query string.
*
* The hints can be <a href="https://www.cockroachlabs.com/docs/v23.1/table-expressions#force-index-selection">index selection hints</a>
* or <a href="https://www.cockroachlabs.com/docs/stable/sql-grammar#opt_join_hint">join hints</a>.
* <p>
* For index selection hints, use the format {@code <tablename>@{FORCE_INDEX=<index>[,<DIRECTION>]}}
* where the optional DIRECTION is either ASC (ascending) or DESC (descending). Multiple index hints can be provided.
* The effect is that in the final SQL statement the hint is added to the table name mentioned in the hint.
*<p>
* For join hints, use the format {@code "<MERGE|HASH|LOOKUP|INVERTED> JOIN"}. Only one join hint will be added. It is
* applied to all join statements in the SQL statement.
* <p>
* Hints are only added to select statements.
*
* @param query The query to which to apply the hint.
* @param hintList The hints to apply
*
* @return the query with hints added
*/
@Override
public String getQueryHintString(String query, List<String> hintList) {
return new CockroachDialectQueryHints(query, hintList).getQueryHintString();
}
// CockroachDB doesn't support this by default. See sql.multiple_modifications_of_table.enabled
//
// @Override

View File

@ -0,0 +1,131 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.dialect;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
class CockroachDialectQueryHints {
final private Pattern TABLE_QUERY_PATTERN = Pattern.compile(
"(?i)^\\s*(select\\b.+?\\bfrom\\b)(.+?)(\\bwhere\\b.+?)$" );
final private Pattern JOIN_HINT_PATTERN = Pattern.compile( "(?i)(MERGE|HASH|LOOKUP|INVERTED)\\s+JOIN" );
//If matched, group 1 contains everything before the join keyword.
final private Pattern JOIN_PATTERN = Pattern.compile(
"(?i)\\b(cross|natural\\s+(.*)\\b|(full|left|right)(\\s+outer)?)?\\s+join" );
final private String query;
final private List<String> hints;
public CockroachDialectQueryHints(String query, List<String> hintList) {
this.query = query;
this.hints = hintList;
}
public String getQueryHintString() {
List<IndexHint> indexHints = new ArrayList<>();
JoinHint joinHint = null;
for ( var h : hints ) {
IndexHint indexHint = parseIndexHints( h );
if ( indexHint != null ) {
indexHints.add( indexHint );
continue;
}
joinHint = parseJoinHints( h );
}
String result = addIndexHints( query, indexHints );
return joinHint == null ? result : addJoinHint( query, joinHint );
}
private IndexHint parseIndexHints(String hint) {
var parts = hint.split( "@" );
if ( parts.length == 2 ) {
return new IndexHint( parts[0], hint );
}
return null;
}
private JoinHint parseJoinHints(String hint) {
var matcher = JOIN_HINT_PATTERN.matcher( hint );
if ( matcher.find() ) {
return new JoinHint( matcher.group( 1 ) );
}
return null;
}
String addIndexHints(String query, List<IndexHint> hints) {
Matcher statementMatcher = TABLE_QUERY_PATTERN.matcher( query );
if ( statementMatcher.matches() && statementMatcher.groupCount() > 2 ) {
String prefix = statementMatcher.group( 1 );
String fromList = statementMatcher.group( 2 );
String suffix = statementMatcher.group( 3 );
fromList = addIndexHintsToFromList( fromList, hints );
return prefix + fromList + suffix;
}
else {
return query;
}
}
String addJoinHint(String query, JoinHint hint) {
var m = JOIN_PATTERN.matcher( query );
StringBuilder buffer = new StringBuilder();
int start = 0;
while ( m.find() ) {
buffer.append( query.substring( start, m.start() ) );
if ( m.group( 1 ) == null ) {
buffer.append( " inner" );
}
else {
buffer.append( m.group( 1 ) );
}
buffer.append( " " )
.append( hint.joinType )
.append( " join" );
start = m.end();
}
buffer.append( query.substring( start ) );
return buffer.toString();
}
String addIndexHintsToFromList(String fromList, List<IndexHint> hints) {
String result = fromList;
for ( var hint : hints ) {
result = result.replaceAll( "(?i)\\b" + hint.table + "\\b", hint.text );
}
return result;
}
static class IndexHint {
final String table;
final String text;
IndexHint(String table, String text) {
this.table = table;
this.text = text;
}
}
static class JoinHint {
final String joinType;
JoinHint(String type) {
this.joinType = type.toLowerCase( Locale.ROOT );
}
}
}

View File

@ -0,0 +1,182 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.dialect.functional;
import java.util.HashSet;
import java.util.Set;
import org.hibernate.dialect.CockroachDialect;
import org.hibernate.testing.jdbc.SQLStatementInspector;
import org.hibernate.testing.orm.junit.DomainModel;
import org.hibernate.testing.orm.junit.JiraKey;
import org.hibernate.testing.orm.junit.RequiresDialect;
import org.hibernate.testing.orm.junit.SessionFactory;
import org.hibernate.testing.orm.junit.SessionFactoryScope;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import jakarta.persistence.Entity;
import jakarta.persistence.Id;
import jakarta.persistence.Index;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;
import jakarta.persistence.TypedQuery;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
@RequiresDialect(CockroachDialect.class)
@SessionFactory(useCollectingStatementInspector = true)
@DomainModel(annotatedClasses = {
SimpleEntity.class, ChildEntity.class
})
@JiraKey("HHH-16867")
public class CockroachDBQueryHintsTest {
@BeforeAll
public void setUp(SessionFactoryScope scope) {
scope.inTransaction( session -> {
var se1 = new SimpleEntity( 1, "se1" );
se1.addChild( new ChildEntity( "se1child1" ) );
session.persist( se1 );
var se2 = new SimpleEntity( 2, "se2" );
session.persist( se2 );
var se3 = new SimpleEntity( 3, "se3" );
session.persist( se3 );
} );
}
@Test
public void testIndexHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<Integer> query = session.createQuery( "select id from SimpleEntity where id < 3", Integer.class )
.addQueryHint( "parents@{FORCE_INDEX=idx,DESC}" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" parents@{FORCE_INDEX=idx,DESC} " );
}
@Test
public void testJoinHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<ChildEntity> query = session.createQuery(
"select c from SimpleEntity s join s.children c where s.id < 3",
ChildEntity.class
)
.addQueryHint( "haSh join" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" hash join " );
}
@Test
public void testOuterJoinHint(SessionFactoryScope scope) {
final SQLStatementInspector statementInspector = scope.getCollectingStatementInspector();
statementInspector.clear();
scope.inTransaction( session -> {
TypedQuery<ChildEntity> query = session.createQuery(
"select c from SimpleEntity s left join s.children c where s.id < 3",
ChildEntity.class
)
.addQueryHint( "haSh join" );
var ignored = query.getResultList();
} );
assertThat( statementInspector.getSqlQueries().get( 0 ) ).contains(
" hash join " );
}
}
@Entity
@Table(name = "children")
class ChildEntity {
@Id
private Integer id;
private String childName;
@ManyToOne
@JoinColumn(name = "parent_id", nullable = false)
private SimpleEntity parent;
public ChildEntity() {
}
public ChildEntity(String childName) {
this.childName = childName;
}
public Integer getId() {
return id;
}
public void setId(Integer id) {
this.id = id;
}
public SimpleEntity getParent() {
return parent;
}
}
@Entity(name = "SimpleEntity")
@Table(name = "parents", indexes = { @Index(name = "idx", columnList = "id") })
class SimpleEntity {
@Id
private Integer id;
private String name;
@OneToMany(mappedBy = "parent")
private Set<ChildEntity> children;
public SimpleEntity() {
}
public SimpleEntity(Integer id, String name) {
this.id = id;
this.name = name;
this.children = new HashSet<>();
}
public Integer getId() {
return id;
}
public void setId(Integer id) {
this.id = id;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public Set<ChildEntity> getChildren() {
return children;
}
public void setChildren(Set<ChildEntity> children) {
this.children = children;
}
public void addChild(ChildEntity child) {
this.children.add( child );
}
}

View File

@ -0,0 +1,82 @@
/*
* Hibernate, Relational Persistence for Idiomatic Java
*
* License: GNU Lesser General Public License (LGPL), version 2.1 or later.
* See the lgpl.txt file in the root directory or <http://www.gnu.org/licenses/lgpl-2.1.html>.
*/
package org.hibernate.orm.test.dialect.unit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertNull;
/**
* Parses join expressions for CockroachDBDialect.
* <p>
* We need to (re)parse the join expression in order to add join hints
* to the generated SQL statement.
*
* @author Karel Maesen
*/
public class CockroachDialectParseJoinExpression {
final private Pattern JOIN_PATTERN = Pattern.compile(
"(?i)\\b(cross|natural\\s+(.*)\\b|(full|left|right)(\\s+outer)?)?\\s+join" );
@Test
public void testSimpleJoin() {
Matcher m = JOIN_PATTERN.matcher( "abc join def" );
if ( m.find() ) {
assertNull( m.group( 1 ) );
}
else {
Assertions.fail( "No match" );
}
}
@Test
public void testCross() {
testJoinMatch( "CRoss join", 1, "cross" );
}
@Test
public void testNatural() {
testJoinMatch( "natural left outer join", 1, "natural left outer" );
}
@Test
public void testLeftOuterJoin() {
testJoinMatch( "left outer join", 1, "left outer" );
}
@Test
public void testLeftJoin() {
testJoinMatch( "left join", 1, "left" );
}
@Test
public void testNaturalJoinType() {
testJoinMatch( "natural left outer join", 2, "left outer" );
}
public void testJoinMatch(String input, int group, String expects) {
testMatch( JOIN_PATTERN, input, group, expects );
}
public void testMatch(Pattern pattern, String input, int group, String expects) {
Matcher m = pattern.matcher( input );
if ( m.find() ) {
Assertions.assertEquals( expects, m.group( group ).toLowerCase() );
}
else {
Assertions.fail( "No match" );
}
}
}