diff --git a/tooling/metamodel-generator/src/main/java/org/hibernate/jpamodelgen/ImportContextImpl.java b/tooling/metamodel-generator/src/main/java/org/hibernate/jpamodelgen/ImportContextImpl.java index 0b8c106d3d..0d778f40c0 100644 --- a/tooling/metamodel-generator/src/main/java/org/hibernate/jpamodelgen/ImportContextImpl.java +++ b/tooling/metamodel-generator/src/main/java/org/hibernate/jpamodelgen/ImportContextImpl.java @@ -11,6 +11,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.TreeSet; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.hibernate.jpamodelgen.model.ImportContext; @@ -22,6 +24,8 @@ import org.hibernate.jpamodelgen.model.ImportContext; */ public class ImportContextImpl implements ImportContext { + private static final Pattern P_IMPORT = Pattern.compile( "(\\?\\s+\\w+\\s+)(.+)" ); + private final Set imports = new TreeSet<>(); private final Set staticImports = new TreeSet<>(); private final Map simpleNames = new HashMap<>(); @@ -66,13 +70,20 @@ public class ImportContextImpl implements ImportContext { // strip off type annotations and '? super' or '? extends' String preamble = ""; - if ( result.startsWith("@") || result.startsWith("?") ) { + if ( result.startsWith( "@" ) ) { int index = result.lastIndexOf(' '); if ( index > 0 ) { preamble = result.substring( 0, index+1 ); result = result.substring( index+1 ); } } + else if ( result.startsWith( "?" ) ) { + final Matcher m = P_IMPORT.matcher( result ); + if ( m.matches() ) { + preamble = m.group( 1 ); + result = Objects.requireNonNullElse( m.group( 2 ), "" ); + } + } String appendices = ""; if ( result.indexOf( '<' ) >= 0 ) { @@ -123,11 +134,24 @@ public class ImportContextImpl implements ImportContext { private String importTypes(String originalArgList) { String[] args = originalArgList.split(","); StringBuilder argList = new StringBuilder(); + StringBuilder acc = new StringBuilder(); for ( String arg : args ) { + if ( acc.length() > 0 ) { + acc.append( ',' ); + } + acc.append( arg ); + final int count = acc.chars().reduce( + 0, + (left, right) -> left + ( right == '<' ? 1 : right == '>' ? -1 : 0 ) + ); + if ( count > 0 ) { + continue; + } if ( argList.length() > 0 ) { argList.append(','); } - argList.append( importType( arg ) ); + argList.append( importType( acc.toString() ) ); + acc.setLength( 0 ); } return argList.toString(); }