Fix casting bug in compound assignment for String (#65329)

This change fixes a bug where when doing compound assignment involving String concatenation, the 
right-hand side will fail to cast to String appropriately and throw a ClassCastException.
This commit is contained in:
Jack Conradson 2020-11-20 12:04:28 -08:00
parent 56605e4d9a
commit 89ec7db26b
2 changed files with 94 additions and 1 deletions

View File

@ -935,6 +935,7 @@ public class DefaultSemanticAnalysisPhase extends UserTreeBaseVisitor<SemanticSc
Class<?> rightValueType = semanticScope.getDecoration(userRightNode, ValueType.class).getValueType();
Class<?> compoundType;
boolean isConcatenation = false;
Class<?> shiftType = null;
boolean isShift = false;
@ -946,6 +947,7 @@ public class DefaultSemanticAnalysisPhase extends UserTreeBaseVisitor<SemanticSc
compoundType = AnalyzerCaster.promoteNumeric(leftValueType, rightValueType, true);
} else if (operation == Operation.ADD) {
compoundType = AnalyzerCaster.promoteAdd(leftValueType, rightValueType);
isConcatenation = compoundType == String.class;
} else if (operation == Operation.SUB) {
compoundType = AnalyzerCaster.promoteNumeric(leftValueType, rightValueType, true);
} else if (operation == Operation.LSH) {
@ -975,7 +977,9 @@ public class DefaultSemanticAnalysisPhase extends UserTreeBaseVisitor<SemanticSc
"cannot apply [" + operation.symbol + "=] to types [" + leftValueType + "] and [" + rightValueType + "]"));
}
if (isShift) {
if (isConcatenation) {
semanticScope.putDecoration(userRightNode, new TargetType(rightValueType));
} else if (isShift) {
if (compoundType == def.class) {
// shifts are promoted independently, but for the def type, we need object.
semanticScope.putDecoration(userRightNode, new TargetType(def.class));

View File

@ -345,4 +345,93 @@ public class GeneralCastTests extends ScriptTestCase {
expectScriptThrows(ClassCastException.class, () -> exec("def x = 2.0; def y = 1; y.compareTo(x);"));
expectScriptThrows(ClassCastException.class, () -> exec("float f = 1.0f; def y = 1; y.compareTo(f);"));
}
public void testCompoundAssignmentStringCasts() {
assertEquals("s71", exec("String s = 's'; byte c = 71; s += c; return s"));
assertEquals("s71", exec("String s = 's'; short c = 71; s += c; return s"));
assertEquals("sG", exec("String s = 's'; char c = 71; s += c; return s"));
assertEquals("s71", exec("String s = 's'; int c = 71; s += c; return s"));
assertEquals("s71", exec("String s = 's'; long c = 71; s += c; return s"));
assertEquals("s71.0", exec("String s = 's'; float c = 71; s += c; return s"));
assertEquals("s71.0", exec("String s = 's'; double c = 71; s += c; return s"));
assertEquals("s71", exec("String s = 's'; String c = '71'; s += c; return s"));
assertEquals("s[71]", exec("String s = 's'; List c = [71]; s += c; return s"));
assertEquals("s71s", exec("String s = 's'; byte c = 71; s += c + s; return s"));
assertEquals("s71s", exec("String s = 's'; short c = 71; s += c + s; return s"));
assertEquals("sGs", exec("String s = 's'; char c = 71; s += c + s; return s"));
assertEquals("s71s", exec("String s = 's'; int c = 71; s += c + s; return s"));
assertEquals("s71s", exec("String s = 's'; long c = 71; s += c + s; return s"));
assertEquals("s71.0s", exec("String s = 's'; float c = 71; s += c + s; return s"));
assertEquals("s71.0s", exec("String s = 's'; double c = 71; s += c + s; return s"));
assertEquals("s71s", exec("String s = 's'; String c = '71'; s += c + s; return s"));
assertEquals("s[71]s", exec("String s = 's'; List c = [71]; s += c + s; return s"));
assertEquals("s142", exec("String s = 's'; byte c = 71; s += c + c; return s"));
assertEquals("s142", exec("String s = 's'; short c = 71; s += c + c; return s"));
assertEquals("s142", exec("String s = 's'; char c = 71; s += c + c; return s"));
assertEquals("s142", exec("String s = 's'; int c = 71; s += c + c; return s"));
assertEquals("s142", exec("String s = 's'; long c = 71; s += c + c; return s"));
assertEquals("s142.0", exec("String s = 's'; float c = 71; s += c + c; return s"));
assertEquals("s142.0", exec("String s = 's'; double c = 71; s += c + c; return s"));
assertEquals("s7171", exec("String s = 's'; String c = '71'; s += c + c; return s"));
assertEquals("s7171", exec("String s = 's'; byte c = 71; s += c + '' + c; return s"));
assertEquals("s7171", exec("String s = 's'; short c = 71; s += c + '' + c; return s"));
assertEquals("sGG", exec("String s = 's'; char c = 71; s += c + '' + c; return s"));
assertEquals("s7171", exec("String s = 's'; int c = 71; s += c + '' + c; return s"));
assertEquals("s7171", exec("String s = 's'; long c = 71; s += c + '' + c; return s"));
assertEquals("s71.071.0", exec("String s = 's'; float c = 71; s += c + '' + c; return s"));
assertEquals("s71.071.0", exec("String s = 's'; double c = 71; s += c + '' + c; return s"));
assertEquals("s7171", exec("String s = 's'; String c = '71'; s += c + '' + c; return s"));
assertEquals("s[71][71]", exec("String s = 's'; List c = [71]; s += c + '' + c; return s"));
assertEquals("s142", exec("String s = 's'; byte c = 71; s += c + c + ''; return s"));
assertEquals("s142", exec("String s = 's'; short c = 71; s += c + c + ''; return s"));
assertEquals("s142", exec("String s = 's'; char c = 71; s += c + c + ''; return s"));
assertEquals("s142", exec("String s = 's'; int c = 71; s += c + c + ''; return s"));
assertEquals("s142", exec("String s = 's'; long c = 71; s += c + c + ''; return s"));
assertEquals("s142.0", exec("String s = 's'; float c = 71; s += c + c + ''; return s"));
assertEquals("s142.0", exec("String s = 's'; double c = 71; s += c + c + ''; return s"));
assertEquals("s7171", exec("String s = 's'; String c = '71'; s += c + c + ''; return s"));
assertEquals("s7171", exec("String s = 's'; byte c = 71; s += '' + c + c; return s"));
assertEquals("s7171", exec("String s = 's'; short c = 71; s += '' + c + c; return s"));
assertEquals("sGG", exec("String s = 's'; char c = 71; s += '' + c + c; return s"));
assertEquals("s7171", exec("String s = 's'; int c = 71; s += '' + c + c; return s"));
assertEquals("s7171", exec("String s = 's'; long c = 71; s += '' + c + c; return s"));
assertEquals("s71.071.0", exec("String s = 's'; float c = 71; s += '' + c + c; return s"));
assertEquals("s71.071.0", exec("String s = 's'; double c = 71; s += '' + c + c; return s"));
assertEquals("s7171", exec("String s = 's'; String c = '71'; s += '' + c + c; return s"));
assertEquals("s[71][71]", exec("String s = 's'; List c = [71]; s += '' + c + c; return s"));
assertEquals("s71s71", exec("String s = 's'; byte c = 71; s += c + s + c; return s"));
assertEquals("s71s71", exec("String s = 's'; short c = 71; s += c + s + c; return s"));
assertEquals("sGsG", exec("String s = 's'; char c = 71; s += c + s + c; return s"));
assertEquals("s71s71", exec("String s = 's'; int c = 71; s += c + s + c; return s"));
assertEquals("s71s71", exec("String s = 's'; long c = 71; s += c + s + c; return s"));
assertEquals("s71.0s71.0", exec("String s = 's'; float c = 71; s += c + s + c; return s"));
assertEquals("s71.0s71.0", exec("String s = 's'; double c = 71; s += c + s + c; return s"));
assertEquals("s71s71", exec("String s = 's'; String c = '71'; s += c + s + c; return s"));
assertEquals("s[71]s[71]", exec("String s = 's'; List c = [71]; s += c + s + c; return s"));
assertEquals("s142s", exec("String s = 's'; byte c = 71; s += c + c + s; return s"));
assertEquals("s142s", exec("String s = 's'; short c = 71; s += c + c + s; return s"));
assertEquals("s142s", exec("String s = 's'; char c = 71; s += c + c + s; return s"));
assertEquals("s142s", exec("String s = 's'; int c = 71; s += c + c + s; return s"));
assertEquals("s142s", exec("String s = 's'; long c = 71; s += c + c + s; return s"));
assertEquals("s142.0s", exec("String s = 's'; float c = 71; s += c + c + s; return s"));
assertEquals("s142.0s", exec("String s = 's'; double c = 71; s += c + c + s; return s"));
assertEquals("s7171s", exec("String s = 's'; String c = '71'; s += c + c + s; return s"));
assertEquals("ss7171", exec("String s = 's'; byte c = 71; s += s + c + c; return s"));
assertEquals("ss7171", exec("String s = 's'; short c = 71; s += s + c + c; return s"));
assertEquals("ssGG", exec("String s = 's'; char c = 71; s += s + c + c; return s"));
assertEquals("ss7171", exec("String s = 's'; int c = 71; s += s + c + c; return s"));
assertEquals("ss7171", exec("String s = 's'; long c = 71; s += s + c + c; return s"));
assertEquals("ss71.071.0", exec("String s = 's'; float c = 71; s += s + c + c; return s"));
assertEquals("ss71.071.0", exec("String s = 's'; double c = 71; s += s + c + c; return s"));
assertEquals("ss7171", exec("String s = 's'; String c = '71'; s += s + c + c; return s"));
assertEquals("ss[71][71]", exec("String s = 's'; List c = [71]; s += s + c + c; return s"));
}
}