diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 158eb3d9c8d..3152fb26d5a 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -130,19 +130,19 @@ public abstract class ExprEval if (coercedType == Long.class || coercedType == Integer.class) { return new NonnullPair<>( ExpressionType.LONG_ARRAY, - val.stream().map(x -> x != null ? ((Number) x).longValue() : null).toArray() + val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.LONG, x).value() : null).toArray() ); } if (coercedType == Float.class || coercedType == Double.class) { return new NonnullPair<>( ExpressionType.DOUBLE_ARRAY, - val.stream().map(x -> x != null ? ((Number) x).doubleValue() : null).toArray() + val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.DOUBLE, x).value() : null).toArray() ); } // default to string return new NonnullPair<>( ExpressionType.STRING_ARRAY, - val.stream().map(x -> x != null ? x.toString() : null).toArray() + val.stream().map(x -> x != null ? ExprEval.ofType(ExpressionType.STRING, x).value() : null).toArray() ); } if (homogenizeMultiValueStrings) { @@ -194,7 +194,7 @@ public abstract class ExprEval */ private static Class convertType(@Nullable Class existing, Class next) { - if (Number.class.isAssignableFrom(next) || next == String.class) { + if (Number.class.isAssignableFrom(next) || next == String.class || next == Boolean.class) { if (existing == null) { return next; } @@ -348,6 +348,12 @@ public abstract class ExprEval } return new LongExprEval((Number) val); } + if (val instanceof Boolean) { + if (ExpressionProcessing.useStrictBooleans()) { + return ofLongBoolean((Boolean) val); + } + return new StringExprEval(String.valueOf(val)); + } if (val instanceof Long[]) { return new ArrayExprEval(ExpressionType.LONG_ARRAY, (Long[]) val); } @@ -360,20 +366,13 @@ public abstract class ExprEval if (val instanceof String[]) { return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) val); } - if (val instanceof Object[]) { - ExpressionType arrayType = findArrayType((Object[]) val); - if (arrayType != null) { - return new ArrayExprEval(arrayType, (Object[]) val); - } - // default to string if array is empty - return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) val); - } - if (val instanceof List) { + if (val instanceof List || val instanceof Object[]) { + final List theList = val instanceof List ? ((List) val) : Arrays.asList((Object[]) val); // do not convert empty lists to arrays with a single null element here, because that should have been done // by the selectors preparing their ObjectBindings if necessary. If we get to this point it was legitimately // empty - NonnullPair coerced = coerceListToArray((List) val, false); + NonnullPair coerced = coerceListToArray(theList, false); if (coerced == null) { return bestEffortOf(null); } @@ -400,7 +399,7 @@ public abstract class ExprEval return new ArrayExprEval(ExpressionType.STRING_ARRAY, (String[]) value); } if (value instanceof Object[]) { - return new ArrayExprEval(ExpressionType.STRING_ARRAY, (Object[]) value); + return bestEffortOf(value); } if (value instanceof List) { return bestEffortOf(value); @@ -413,6 +412,9 @@ public abstract class ExprEval if (value instanceof Number) { return ofLong((Number) value); } + if (value instanceof Boolean) { + return ofLongBoolean((Boolean) value); + } if (value instanceof String) { return ofLong(ExprEval.computeNumber((String) value)); } @@ -421,6 +423,12 @@ public abstract class ExprEval if (value instanceof Number) { return ofDouble((Number) value); } + if (value instanceof Boolean) { + if (ExpressionProcessing.useStrictBooleans()) { + return ofLongBoolean((Boolean) value); + } + return ofDouble(Evals.asDouble((Boolean) value)); + } if (value instanceof String) { return ofDouble(ExprEval.computeNumber((String) value)); } @@ -442,13 +450,14 @@ public abstract class ExprEval return ofComplex(type, value); case ARRAY: - if (value instanceof Object[]) { + // nested arrays, here be dragons... don't do any fancy coercion, assume everything is already sane types... + if (type.getElementType().isArray()) { return ofArray(type, (Object[]) value); } // in a better world, we might get an object that matches the type signature for arrays and could do a switch // statement here, but this is not that world yet, and things that are array typed might also be non-arrays, // e.g. we might get a String instead of String[], so just fallback to bestEffortOf - return bestEffortOf(value); + return bestEffortOf(value).castTo(type); } throw new IAE("Cannot create type [%s]", type); } @@ -459,6 +468,12 @@ public abstract class ExprEval if (value == null) { return null; } + if (Evals.asBoolean(value)) { + return 1.0; + } + if (value.equalsIgnoreCase("false")) { + return 0.0; + } Number rv; Long v = GuavaUtils.tryParseLong(value); // Do NOT use ternary operator here, because it makes Java to convert Long to Double diff --git a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java index 4b85836f7ce..b15f5f81063 100644 --- a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -22,12 +22,20 @@ package org.apache.druid.math.expr; import com.google.common.collect.ImmutableMap; import org.apache.druid.common.config.NullHandling; import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.segment.column.TypeStrategies; +import org.apache.druid.segment.column.TypeStrategiesTest; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; +import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -35,6 +43,16 @@ import static org.junit.Assert.assertNull; */ public class EvalTest extends InitializedNullHandlingTest { + + @BeforeClass + public static void setupClass() + { + TypeStrategies.registerComplex( + TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getComplexTypeName(), + new TypeStrategiesTest.NullableLongPairTypeStrategy() + ); + } + @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -576,4 +594,224 @@ public class EvalTest extends InitializedNullHandlingTest ExpressionProcessing.initializeForTests(null); } } + + @Test + public void testBooleanInputs() + { + Map bindingsMap = new HashMap<>(); + bindingsMap.put("l1", 100L); + bindingsMap.put("l2", 0L); + bindingsMap.put("d1", 1.1); + bindingsMap.put("d2", 0.0); + bindingsMap.put("s1", "true"); + bindingsMap.put("s2", "false"); + bindingsMap.put("b1", true); + bindingsMap.put("b2", false); + Expr.ObjectBinding bindings = InputBindings.withMap(bindingsMap); + + try { + ExpressionProcessing.initializeForStrictBooleansTests(true); + assertEquals(1L, eval("s1 && s1", bindings).value()); + assertEquals(0L, eval("s1 && s2", bindings).value()); + assertEquals(0L, eval("s2 && s1", bindings).value()); + assertEquals(0L, eval("s2 && s2", bindings).value()); + + assertEquals(1L, eval("s1 || s1", bindings).value()); + assertEquals(1L, eval("s1 || s2", bindings).value()); + assertEquals(1L, eval("s2 || s1", bindings).value()); + assertEquals(0L, eval("s2 || s2", bindings).value()); + + assertEquals(1L, eval("l1 && l1", bindings).value()); + assertEquals(0L, eval("l1 && l2", bindings).value()); + assertEquals(0L, eval("l2 && l1", bindings).value()); + assertEquals(0L, eval("l2 && l2", bindings).value()); + + assertEquals(1L, eval("b1 && b1", bindings).value()); + assertEquals(0L, eval("b1 && b2", bindings).value()); + assertEquals(0L, eval("b2 && b1", bindings).value()); + assertEquals(0L, eval("b2 && b2", bindings).value()); + + assertEquals(1L, eval("d1 && d1", bindings).value()); + assertEquals(0L, eval("d1 && d2", bindings).value()); + assertEquals(0L, eval("d2 && d1", bindings).value()); + assertEquals(0L, eval("d2 && d2", bindings).value()); + + assertEquals(1L, eval("b1", bindings).value()); + assertEquals(1L, eval("if(b1,1,0)", bindings).value()); + assertEquals(1L, eval("if(l1,1,0)", bindings).value()); + assertEquals(1L, eval("if(d1,1,0)", bindings).value()); + assertEquals(1L, eval("if(s1,1,0)", bindings).value()); + assertEquals(0L, eval("if(b2,1,0)", bindings).value()); + assertEquals(0L, eval("if(l2,1,0)", bindings).value()); + assertEquals(0L, eval("if(d2,1,0)", bindings).value()); + assertEquals(0L, eval("if(s2,1,0)", bindings).value()); + } + finally { + // reset + ExpressionProcessing.initializeForTests(null); + } + + try { + // turn on legacy insanity mode + ExpressionProcessing.initializeForStrictBooleansTests(false); + + assertEquals("true", eval("s1 && s1", bindings).value()); + assertEquals("false", eval("s1 && s2", bindings).value()); + assertEquals("false", eval("s2 && s1", bindings).value()); + assertEquals("false", eval("s2 && s2", bindings).value()); + + assertEquals("true", eval("b1 && b1", bindings).value()); + assertEquals("false", eval("b1 && b2", bindings).value()); + assertEquals("false", eval("b2 && b1", bindings).value()); + assertEquals("false", eval("b2 && b2", bindings).value()); + + assertEquals(100L, eval("l1 && l1", bindings).value()); + assertEquals(0L, eval("l1 && l2", bindings).value()); + assertEquals(0L, eval("l2 && l1", bindings).value()); + assertEquals(0L, eval("l2 && l2", bindings).value()); + + assertEquals(1.1, eval("d1 && d1", bindings).value()); + assertEquals(0.0, eval("d1 && d2", bindings).value()); + assertEquals(0.0, eval("d2 && d1", bindings).value()); + assertEquals(0.0, eval("d2 && d2", bindings).value()); + + assertEquals("true", eval("b1", bindings).value()); + assertEquals(1L, eval("if(b1,1,0)", bindings).value()); + assertEquals(1L, eval("if(l1,1,0)", bindings).value()); + assertEquals(1L, eval("if(d1,1,0)", bindings).value()); + assertEquals(1L, eval("if(s1,1,0)", bindings).value()); + assertEquals(0L, eval("if(b2,1,0)", bindings).value()); + assertEquals(0L, eval("if(l2,1,0)", bindings).value()); + assertEquals(0L, eval("if(d2,1,0)", bindings).value()); + assertEquals(0L, eval("if(s2,1,0)", bindings).value()); + } + finally { + // reset + ExpressionProcessing.initializeForTests(null); + } + } + + @Test + public void testEvalOfType() + { + // strings + ExprEval eval = ExprEval.ofType(ExpressionType.STRING, "stringy"); + Assert.assertEquals(ExpressionType.STRING, eval.type()); + Assert.assertEquals("stringy", eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING, 1L); + Assert.assertEquals(ExpressionType.STRING, eval.type()); + Assert.assertEquals("1", eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING, 1.0); + Assert.assertEquals(ExpressionType.STRING, eval.type()); + Assert.assertEquals("1.0", eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING, true); + Assert.assertEquals(ExpressionType.STRING, eval.type()); + Assert.assertEquals("true", eval.value()); + + // longs + eval = ExprEval.ofType(ExpressionType.LONG, 1L); + Assert.assertEquals(ExpressionType.LONG, eval.type()); + Assert.assertEquals(1L, eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG, 1.0); + Assert.assertEquals(ExpressionType.LONG, eval.type()); + Assert.assertEquals(1L, eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG, "1"); + Assert.assertEquals(ExpressionType.LONG, eval.type()); + Assert.assertEquals(1L, eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG, true); + Assert.assertEquals(ExpressionType.LONG, eval.type()); + Assert.assertEquals(1L, eval.value()); + + // doubles + eval = ExprEval.ofType(ExpressionType.DOUBLE, 1L); + Assert.assertEquals(ExpressionType.DOUBLE, eval.type()); + Assert.assertEquals(1.0, eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE, 1.0); + Assert.assertEquals(ExpressionType.DOUBLE, eval.type()); + Assert.assertEquals(1.0, eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE, "1"); + Assert.assertEquals(ExpressionType.DOUBLE, eval.type()); + Assert.assertEquals(1.0, eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE, true); + Assert.assertEquals(ExpressionType.DOUBLE, eval.type()); + Assert.assertEquals(1.0, eval.value()); + + // complex + TypeStrategiesTest.NullableLongPair pair = new TypeStrategiesTest.NullableLongPair(1L, 2L); + ExpressionType type = ExpressionType.fromColumnType(TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE); + + eval = ExprEval.ofType(type, pair); + Assert.assertEquals(type, eval.type()); + Assert.assertEquals(pair, eval.value()); + + ByteBuffer buffer = ByteBuffer.allocate(TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getStrategy().estimateSizeBytes(pair)); + TypeStrategiesTest.NULLABLE_TEST_PAIR_TYPE.getStrategy().write(buffer, pair, buffer.limit()); + byte[] pairBytes = buffer.array(); + eval = ExprEval.ofType(type, pairBytes); + Assert.assertEquals(type, eval.type()); + Assert.assertEquals(pair, eval.value()); + + eval = ExprEval.ofType(type, StringUtils.encodeBase64String(pairBytes)); + Assert.assertEquals(type, eval.type()); + Assert.assertEquals(pair, eval.value()); + + // arrays fall back to using 'bestEffortOf', but cast it to the expected output type + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {"1", "2", "3"}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2L, "3", true, false}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L, 1L, 0L}, (Object[]) eval.value()); + + // etc + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {"1", "2", "3"}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2L, "3", true, false}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0, 1.0, 0.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"1", "2", "3"}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {"1", "2", "3"}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {"1", "2", "3"}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {"1.0", "2.0", "3.0"}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {1.0, 2L, "3", true, false}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {"1.0", "2", "3", "true", "false"}, (Object[]) eval.value()); + } } diff --git a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java index fe5c2061260..499bcef08fe 100644 --- a/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java +++ b/processing/src/test/java/org/apache/druid/query/aggregation/ExpressionLambdaAggregatorFactoryTest.java @@ -254,7 +254,7 @@ public class ExpressionLambdaAggregatorFactoryTest extends InitializedNullHandli ImmutableSet.of("x"), null, "0", - "ARRAY[]", + "ARRAY[]", true, true, false, diff --git a/processing/src/test/java/org/apache/druid/segment/transform/TransformSpecTest.java b/processing/src/test/java/org/apache/druid/segment/transform/TransformSpecTest.java index 3537c4cef88..b9244b3ae10 100644 --- a/processing/src/test/java/org/apache/druid/segment/transform/TransformSpecTest.java +++ b/processing/src/test/java/org/apache/druid/segment/transform/TransformSpecTest.java @@ -30,6 +30,7 @@ import org.apache.druid.data.input.impl.MapInputRowParser; import org.apache.druid.data.input.impl.TimeAndDimsParseSpec; import org.apache.druid.data.input.impl.TimestampSpec; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.SelectorDimFilter; @@ -50,18 +51,20 @@ public class TransformSpecTest extends InitializedNullHandlingTest ); private static final Map ROW1 = ImmutableMap.builder() - .put("x", "foo") - .put("y", "bar") - .put("a", 2.0) - .put("b", 3L) - .build(); + .put("x", "foo") + .put("y", "bar") + .put("a", 2.0) + .put("b", 3L) + .put("bool", true) + .build(); private static final Map ROW2 = ImmutableMap.builder() - .put("x", "foo") - .put("y", "baz") - .put("a", 2.0) - .put("b", 4L) - .build(); + .put("x", "foo") + .put("y", "baz") + .put("a", 2.0) + .put("b", 4L) + .put("bool", false) + .build(); @Test public void testTransforms() @@ -202,6 +205,73 @@ public class TransformSpecTest extends InitializedNullHandlingTest Assert.assertEquals(DateTimes.of("2000-01-01T01:00:00Z").getMillis(), row.getTimestampFromEpoch()); } + @Test + public void testBoolTransforms() + { + try { + ExpressionProcessing.initializeForStrictBooleansTests(true); + final TransformSpec transformSpec = new TransformSpec( + null, + ImmutableList.of( + new ExpressionTransform("truthy1", "bool", TestExprMacroTable.INSTANCE), + new ExpressionTransform("truthy2", "if(bool,1,0)", TestExprMacroTable.INSTANCE) + ) + ); + + Assert.assertEquals( + ImmutableSet.of("bool"), + transformSpec.getRequiredColumns() + ); + + final InputRowParser> parser = transformSpec.decorate(PARSER); + final InputRow row = parser.parseBatch(ROW1).get(0); + + Assert.assertNotNull(row); + Assert.assertEquals(1L, row.getRaw("truthy1")); + Assert.assertEquals(1L, row.getRaw("truthy2")); + + final InputRow row2 = parser.parseBatch(ROW2).get(0); + + Assert.assertNotNull(row2); + Assert.assertEquals(0L, row2.getRaw("truthy1")); + Assert.assertEquals(0L, row2.getRaw("truthy2")); + } + finally { + ExpressionProcessing.initializeForTests(null); + } + try { + ExpressionProcessing.initializeForStrictBooleansTests(false); + final TransformSpec transformSpec = new TransformSpec( + null, + ImmutableList.of( + new ExpressionTransform("truthy1", "bool", TestExprMacroTable.INSTANCE), + new ExpressionTransform("truthy2", "if(bool,1,0)", TestExprMacroTable.INSTANCE) + ) + ); + + Assert.assertEquals( + ImmutableSet.of("bool"), + transformSpec.getRequiredColumns() + ); + + final InputRowParser> parser = transformSpec.decorate(PARSER); + final InputRow row = parser.parseBatch(ROW1).get(0); + + Assert.assertNotNull(row); + Assert.assertEquals("true", row.getRaw("truthy1")); + Assert.assertEquals(1L, row.getRaw("truthy2")); + + final InputRow row2 = parser.parseBatch(ROW2).get(0); + + Assert.assertNotNull(row2); + Assert.assertEquals("false", row2.getRaw("truthy1")); + Assert.assertEquals(0L, row2.getRaw("truthy2")); + } + finally { + ExpressionProcessing.initializeForTests(null); + } + } + @Test public void testSerde() throws Exception {