fix issue with boolean expression input (#12429)

This commit is contained in:
Clint Wylie 2022-04-13 16:34:01 -07:00 committed by GitHub
parent 5d37d9f9d8
commit 5824ab9608
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 351 additions and 28 deletions

View File

@ -130,19 +130,19 @@ public abstract class ExprEval<T>
if (coercedType == Long.class || coercedType == Integer.class) {
return new NonnullPair<>(
ExpressionType.LONG_ARRAY,
val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray()
val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.LONG, x).value() : null).toArray()
);
}
if (coercedType == Float.class || coercedType == Double.class) {
return new NonnullPair<>(
ExpressionType.DOUBLE_ARRAY,
val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray()
val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.DOUBLE, x).value() : null).toArray()
);
}
// default to string
return new NonnullPair<>(
ExpressionType.STRING_ARRAY,
val.stream().map(x -> x != null ? x.toString() : null).toArray()
val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.STRING, x).value() : null).toArray()
);
}
if (homogenizeMultiValueStrings) {
@ -194,7 +194,7 @@ public abstract class ExprEval<T>
*/
private static Class convertType(@Nullable Class existing, Class next)
{
if (Number.class.isAssignableFrom(next) || next == String.class) {
if (Number.class.isAssignableFrom(next) || next == String.class || next == Boolean.class) {
if (existing == null) {
return next;
}
@ -348,6 +348,12 @@ public abstract class ExprEval<T>
}
return new LongExprEval((Number) val);
}
if (val instanceof Boolean) {
if (ExpressionProcessing.useStrictBooleans()) {
return ofLongBoolean((Boolean) val);
}
return new StringExprEval(String.valueOf(val));
}
if (val instanceof Long[]) {
return new ArrayExprEval(ExpressionType.LONG_ARRAY, (Long[]) val);
}
@ -360,20 +366,13 @@ public abstract class ExprEval<T>
if (val instanceof String[]) {
return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) val);
}
if (val instanceof Object[]) {
ExpressionType arrayType = findArrayType((Object[]) val);
if (arrayType != null) {
return new ArrayExprEval(arrayType, (Object[]) val);
}
// default to string if array is empty
return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) val);
}
if (val instanceof List) {
if (val instanceof List || val instanceof Object[]) {
final List<?> theList = val instanceof List ? ((List<?>) val) : Arrays.asList((Object[]) val);
// do not convert empty lists to arrays with a single null element here, because that should have been done
// by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately
// empty
NonnullPair<ExpressionType, Object[]> coerced = coerceListToArray((List<?>) val, false);
NonnullPair<ExpressionType, Object[]> coerced = coerceListToArray(theList, false);
if (coerced == null) {
return bestEffortOf(null);
}
@ -400,7 +399,7 @@ public abstract class ExprEval<T>
return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) value);
}
if (value instanceof Object[]) {
return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) value);
return bestEffortOf(value);
}
if (value instanceof List) {
return bestEffortOf(value);
@ -413,6 +412,9 @@ public abstract class ExprEval<T>
if (value instanceof Number) {
return ofLong((Number) value);
}
if (value instanceof Boolean) {
return ofLongBoolean((Boolean) value);
}
if (value instanceof String) {
return ofLong(ExprEval.computeNumber((String) value));
}
@ -421,6 +423,12 @@ public abstract class ExprEval<T>
if (value instanceof Number) {
return ofDouble((Number) value);
}
if (value instanceof Boolean) {
if (ExpressionProcessing.useStrictBooleans()) {
return ofLongBoolean((Boolean) value);
}
return ofDouble(Evals.asDouble((Boolean) value));
}
if (value instanceof String) {
return ofDouble(ExprEval.computeNumber((String) value));
}
@ -442,13 +450,14 @@ public abstract class ExprEval<T>
return ofComplex(type, value);
case ARRAY:
if (value instanceof Object[]) {
// nested arrays, here be dragons... don't do any fancy coercion, assume everything is already sane types...
if (type.getElementType().isArray()) {
return ofArray(type, (Object[]) value);
}
// in a better world, we might get an object that matches the type signature for arrays and could do a switch
// statement here, but this is not that world yet, and things that are array typed might also be non-arrays,
// e.g. we might get a String instead of String[], so just fallback to bestEffortOf
return bestEffortOf(value);
return bestEffortOf(value).castTo(type);
}
throw new IAE("Cannot create type [%s]", type);
}
@ -459,6 +468,12 @@ public abstract class ExprEval<T>
if (value == null) {
return null;
}
if (Evals.asBoolean(value)) {
return 1.0;
}
if (value.equalsIgnoreCase("false")) {
return 0.0;
}
Number rv;
Long v = GuavaUtils.tryParseLong(value);
// Do NOT use ternary operator here, because it makes Java to convert Long to Double

View File

@ -22,12 +22,20 @@ package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.TypeStrategies;
import org.apache.druid.segment.column.TypeStrategiesTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
@ -35,6 +43,16 @@ import static org.junit.Assert.assertNull;
*/
public class EvalTest extends InitializedNullHandlingTest
{
@BeforeClass
public static void setupClass()
{
TypeStrategies.registerComplex(
TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(),
new TypeStrategiesTest.NullableLongPairTypeStrategy()
);
}
@Rule
public ExpectedException expectedException = ExpectedException.none();
@ -576,4 +594,224 @@ public class EvalTest extends InitializedNullHandlingTest
ExpressionProcessing.initializeForTests(null);
}
}
@Test
public void testBooleanInputs()
{
Map<String, Object> bindingsMap = new HashMap<>();
bindingsMap.put("l1", 100L);
bindingsMap.put("l2", 0L);
bindingsMap.put("d1", 1.1);
bindingsMap.put("d2", 0.0);
bindingsMap.put("s1", "true");
bindingsMap.put("s2", "false");
bindingsMap.put("b1", true);
bindingsMap.put("b2", false);
Expr.ObjectBinding bindings = InputBindings.withMap(bindingsMap);
try {
ExpressionProcessing.initializeForStrictBooleansTests(true);
assertEquals(1L, eval("s1 && s1", bindings).value());
assertEquals(0L, eval("s1 && s2", bindings).value());
assertEquals(0L, eval("s2 && s1", bindings).value());
assertEquals(0L, eval("s2 && s2", bindings).value());
assertEquals(1L, eval("s1 || s1", bindings).value());
assertEquals(1L, eval("s1 || s2", bindings).value());
assertEquals(1L, eval("s2 || s1", bindings).value());
assertEquals(0L, eval("s2 || s2", bindings).value());
assertEquals(1L, eval("l1 && l1", bindings).value());
assertEquals(0L, eval("l1 && l2", bindings).value());
assertEquals(0L, eval("l2 && l1", bindings).value());
assertEquals(0L, eval("l2 && l2", bindings).value());
assertEquals(1L, eval("b1 && b1", bindings).value());
assertEquals(0L, eval("b1 && b2", bindings).value());
assertEquals(0L, eval("b2 && b1", bindings).value());
assertEquals(0L, eval("b2 && b2", bindings).value());
assertEquals(1L, eval("d1 && d1", bindings).value());
assertEquals(0L, eval("d1 && d2", bindings).value());
assertEquals(0L, eval("d2 && d1", bindings).value());
assertEquals(0L, eval("d2 && d2", bindings).value());
assertEquals(1L, eval("b1", bindings).value());
assertEquals(1L, eval("if(b1,1,0)", bindings).value());
assertEquals(1L, eval("if(l1,1,0)", bindings).value());
assertEquals(1L, eval("if(d1,1,0)", bindings).value());
assertEquals(1L, eval("if(s1,1,0)", bindings).value());
assertEquals(0L, eval("if(b2,1,0)", bindings).value());
assertEquals(0L, eval("if(l2,1,0)", bindings).value());
assertEquals(0L, eval("if(d2,1,0)", bindings).value());
assertEquals(0L, eval("if(s2,1,0)", bindings).value());
}
finally {
// reset
ExpressionProcessing.initializeForTests(null);
}
try {
// turn on legacy insanity mode
ExpressionProcessing.initializeForStrictBooleansTests(false);
assertEquals("true", eval("s1 && s1", bindings).value());
assertEquals("false", eval("s1 && s2", bindings).value());
assertEquals("false", eval("s2 && s1", bindings).value());
assertEquals("false", eval("s2 && s2", bindings).value());
assertEquals("true", eval("b1 && b1", bindings).value());
assertEquals("false", eval("b1 && b2", bindings).value());
assertEquals("false", eval("b2 && b1", bindings).value());
assertEquals("false", eval("b2 && b2", bindings).value());
assertEquals(100L, eval("l1 && l1", bindings).value());
assertEquals(0L, eval("l1 && l2", bindings).value());
assertEquals(0L, eval("l2 && l1", bindings).value());
assertEquals(0L, eval("l2 && l2", bindings).value());
assertEquals(1.1, eval("d1 && d1", bindings).value());
assertEquals(0.0, eval("d1 && d2", bindings).value());
assertEquals(0.0, eval("d2 && d1", bindings).value());
assertEquals(0.0, eval("d2 && d2", bindings).value());
assertEquals("true", eval("b1", bindings).value());
assertEquals(1L, eval("if(b1,1,0)", bindings).value());
assertEquals(1L, eval("if(l1,1,0)", bindings).value());
assertEquals(1L, eval("if(d1,1,0)", bindings).value());
assertEquals(1L, eval("if(s1,1,0)", bindings).value());
assertEquals(0L, eval("if(b2,1,0)", bindings).value());
assertEquals(0L, eval("if(l2,1,0)", bindings).value());
assertEquals(0L, eval("if(d2,1,0)", bindings).value());
assertEquals(0L, eval("if(s2,1,0)", bindings).value());
}
finally {
// reset
ExpressionProcessing.initializeForTests(null);
}
}
@Test
public void testEvalOfType()
{
// strings
ExprEval eval = ExprEval.ofType(ExpressionType.STRING, "stringy");
Assert.assertEquals(ExpressionType.STRING, eval.type());
Assert.assertEquals("stringy", eval.value());
eval = ExprEval.ofType(ExpressionType.STRING, 1L);
Assert.assertEquals(ExpressionType.STRING, eval.type());
Assert.assertEquals("1", eval.value());
eval = ExprEval.ofType(ExpressionType.STRING, 1.0);
Assert.assertEquals(ExpressionType.STRING, eval.type());
Assert.assertEquals("1.0", eval.value());
eval = ExprEval.ofType(ExpressionType.STRING, true);
Assert.assertEquals(ExpressionType.STRING, eval.type());
Assert.assertEquals("true", eval.value());
// longs
eval = ExprEval.ofType(ExpressionType.LONG, 1L);
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertEquals(1L, eval.value());
eval = ExprEval.ofType(ExpressionType.LONG, 1.0);
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertEquals(1L, eval.value());
eval = ExprEval.ofType(ExpressionType.LONG, "1");
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertEquals(1L, eval.value());
eval = ExprEval.ofType(ExpressionType.LONG, true);
Assert.assertEquals(ExpressionType.LONG, eval.type());
Assert.assertEquals(1L, eval.value());
// doubles
eval = ExprEval.ofType(ExpressionType.DOUBLE, 1L);
Assert.assertEquals(ExpressionType.DOUBLE, eval.type());
Assert.assertEquals(1.0, eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE, 1.0);
Assert.assertEquals(ExpressionType.DOUBLE, eval.type());
Assert.assertEquals(1.0, eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE, "1");
Assert.assertEquals(ExpressionType.DOUBLE, eval.type());
Assert.assertEquals(1.0, eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE, true);
Assert.assertEquals(ExpressionType.DOUBLE, eval.type());
Assert.assertEquals(1.0, eval.value());
// complex
TypeStrategiesTest.NullableLongPair pair = new TypeStrategiesTest.NullableLongPair(1L, 2L);
ExpressionType type = ExpressionType.fromColumnType(TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE);
eval = ExprEval.ofType(type, pair);
Assert.assertEquals(type, eval.type());
Assert.assertEquals(pair, eval.value());
ByteBuffer buffer = ByteBuffer.allocate(TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getStrategy().estimateSizeBytes(pair));
TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getStrategy().write(buffer, pair, buffer.limit());
byte[] pairBytes = buffer.array();
eval = ExprEval.ofType(type, pairBytes);
Assert.assertEquals(type, eval.type());
Assert.assertEquals(pair, eval.value());
eval = ExprEval.ofType(type, StringUtils.encodeBase64String(pairBytes));
Assert.assertEquals(type, eval.type());
Assert.assertEquals(pair, eval.value());
// arrays fall back to using 'bestEffortOf', but cast it to the expected output type
eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {"1", "2", "3"});
Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L});
Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2.0, 3.0});
Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2L, "3", true, false});
Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1L, 2L, 3L, 1L, 0L}, (Object[]) eval.value());
// etc
eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {"1", "2", "3"});
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1L, 2L, 3L});
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0});
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2L, "3", true, false});
Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0, 1.0, 0.0}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"1", "2", "3"});
Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {"1", "2", "3"}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1L, 2L, 3L});
Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {"1", "2", "3"}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1.0, 2.0, 3.0});
Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {"1.0", "2.0", "3.0"}, (Object[]) eval.value());
eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1.0, 2L, "3", true, false});
Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type());
Assert.assertArrayEquals(new Object[] {"1.0", "2", "3", "true", "false"}, (Object[]) eval.value());
}
}

View File

@ -254,7 +254,7 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli
ImmutableSet.of("x"),
null,
"0",
"ARRAY<LONG>[]",
"ARRAY<STRING>[]",
true,
true,
false,

View File

@ -30,6 +30,7 @@ import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.filter.AndDimFilter;
import org.apache.druid.query.filter.SelectorDimFilter;
@ -54,6 +55,7 @@ public class TransformSpecTest extends InitializedNullHandlingTest
.put("y", "bar")
.put("a", 2.0)
.put("b", 3L)
.put("bool", true)
.build();
private static final Map<String, Object> ROW2 = ImmutableMap.<String, Object>builder()
@ -61,6 +63,7 @@ public class TransformSpecTest extends InitializedNullHandlingTest
.put("y", "baz")
.put("a", 2.0)
.put("b", 4L)
.put("bool", false)
.build();
@Test
@ -202,6 +205,73 @@ public class TransformSpecTest extends InitializedNullHandlingTest
Assert.assertEquals(DateTimes.of("2000-01-01T01:00:00Z").getMillis(), row.getTimestampFromEpoch());
}
@Test
public void testBoolTransforms()
{
try {
ExpressionProcessing.initializeForStrictBooleansTests(true);
final TransformSpec transformSpec = new TransformSpec(
null,
ImmutableList.of(
new ExpressionTransform("truthy1", "bool", TestExprMacroTable.INSTANCE),
new ExpressionTransform("truthy2", "if(bool,1,0)", TestExprMacroTable.INSTANCE)
)
);
Assert.assertEquals(
ImmutableSet.of("bool"),
transformSpec.getRequiredColumns()
);
final InputRowParser<Map<String, Object>> parser = transformSpec.decorate(PARSER);
final InputRow row = parser.parseBatch(ROW1).get(0);
Assert.assertNotNull(row);
Assert.assertEquals(1L, row.getRaw("truthy1"));
Assert.assertEquals(1L, row.getRaw("truthy2"));
final InputRow row2 = parser.parseBatch(ROW2).get(0);
Assert.assertNotNull(row2);
Assert.assertEquals(0L, row2.getRaw("truthy1"));
Assert.assertEquals(0L, row2.getRaw("truthy2"));
}
finally {
ExpressionProcessing.initializeForTests(null);
}
try {
ExpressionProcessing.initializeForStrictBooleansTests(false);
final TransformSpec transformSpec = new TransformSpec(
null,
ImmutableList.of(
new ExpressionTransform("truthy1", "bool", TestExprMacroTable.INSTANCE),
new ExpressionTransform("truthy2", "if(bool,1,0)", TestExprMacroTable.INSTANCE)
)
);
Assert.assertEquals(
ImmutableSet.of("bool"),
transformSpec.getRequiredColumns()
);
final InputRowParser<Map<String, Object>> parser = transformSpec.decorate(PARSER);
final InputRow row = parser.parseBatch(ROW1).get(0);
Assert.assertNotNull(row);
Assert.assertEquals("true", row.getRaw("truthy1"));
Assert.assertEquals(1L, row.getRaw("truthy2"));
final InputRow row2 = parser.parseBatch(ROW2).get(0);
Assert.assertNotNull(row2);
Assert.assertEquals("false", row2.getRaw("truthy1"));
Assert.assertEquals(0L, row2.getRaw("truthy2"));
}
finally {
ExpressionProcessing.initializeForTests(null);
}
}
@Test
public void testSerde() throws Exception
{