diff --git a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java index 1d5298ffc1d..dbfc11d3c4e 100644 --- a/core/src/main/java/org/apache/druid/math/expr/ExprEval.java +++ b/core/src/main/java/org/apache/druid/math/expr/ExprEval.java @@ -170,7 +170,7 @@ public abstract class ExprEval Object[] array = new Object[val.size()]; int i = 0; for (Object o : val) { - array[i++] = o == null ? null : ExprEval.ofType(ExpressionType.DOUBLE, o).value(); + array[i++] = ExprEval.ofType(ExpressionType.DOUBLE, o).value(); } return new NonnullPair<>(ExpressionType.DOUBLE_ARRAY, array); } else if (coercedType == Object.class) { @@ -193,7 +193,7 @@ public abstract class ExprEval if (eval != null) { array[i++] = eval.castTo(elementType).value(); } else { - array[i++] = null; + array[i++] = ExprEval.ofType(elementType, null).value(); } } ExpressionType arrayType = elementType == null @@ -329,7 +329,7 @@ public abstract class ExprEval } - public static ExprEval ofArray(ExpressionType outputType, Object[] value) + public static ExprEval ofArray(ExpressionType outputType, @Nullable Object[] value) { Preconditions.checkArgument(outputType.isArray(), "Output type %s is not an array", outputType); return new ArrayExprEval(outputType, value); @@ -398,7 +398,7 @@ public abstract class ExprEval final Long[] inputArray = (Long[]) val; final Object[] array = new Object[inputArray.length]; for (int i = 0; i < inputArray.length; i++) { - array[i] = inputArray[i]; + array[i] = inputArray[i] != null ? inputArray[i] : NullHandling.defaultLongValue(); } return new ArrayExprEval(ExpressionType.LONG_ARRAY, array); } @@ -414,7 +414,7 @@ public abstract class ExprEval final Integer[] inputArray = (Integer[]) val; final Object[] array = new Object[inputArray.length]; for (int i = 0; i < inputArray.length; i++) { - array[i] = inputArray[i] == null ? null : inputArray[i].longValue(); + array[i] = inputArray[i] != null ? inputArray[i].longValue() : NullHandling.defaultLongValue(); } return new ArrayExprEval(ExpressionType.LONG_ARRAY, array); } @@ -430,7 +430,7 @@ public abstract class ExprEval final Double[] inputArray = (Double[]) val; final Object[] array = new Object[inputArray.length]; for (int i = 0; i < inputArray.length; i++) { - array[i] = inputArray[i]; + array[i] = inputArray[i] != null ? inputArray[i] : NullHandling.defaultDoubleValue(); } return new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, array); } @@ -446,7 +446,7 @@ public abstract class ExprEval final Float[] inputArray = (Float[]) val; final Object[] array = new Object[inputArray.length]; for (int i = 0; i < inputArray.length; i++) { - array[i] = inputArray[i] != null ? inputArray[i].doubleValue() : null; + array[i] = inputArray[i] != null ? inputArray[i].doubleValue() : NullHandling.defaultDoubleValue(); } return new ArrayExprEval(ExpressionType.DOUBLE_ARRAY, array); } @@ -568,9 +568,28 @@ public abstract class ExprEval return ofComplex(type, value); case ARRAY: - // nested arrays, here be dragons... don't do any fancy coercion, assume everything is already sane types... - if (type.getElementType().isArray()) { - return ofArray(type, (Object[]) value); + ExpressionType elementType = (ExpressionType) type.getElementType(); + if (value == null) { + return ofArray(type, null); + } + if (value instanceof List) { + List theList = (List) value; + Object[] array = new Object[theList.size()]; + int i = 0; + for (Object o : theList) { + array[i++] = ExprEval.ofType(elementType, o).value(); + } + return ofArray(type, array); + } + + if (value instanceof Object[]) { + Object[] inputArray = (Object[]) value; + Object[] array = new Object[inputArray.length]; + int i = 0; + for (Object o : inputArray) { + array[i++] = ExprEval.ofType(elementType, o).value(); + } + return ofArray(type, array); } // in a better world, we might get an object that matches the type signature for arrays and could do a switch // statement here, but this is not that world yet, and things that are array typed might also be non-arrays, @@ -798,7 +817,7 @@ public abstract class ExprEval return ExprEval.ofStringArray(value == null ? null : new Object[] {value.toString()}); } } - throw new IAE("invalid type " + castTo); + throw new IAE("invalid type cannot cast " + type() + " to " + castTo); } @Override @@ -863,7 +882,7 @@ public abstract class ExprEval return ExprEval.ofStringArray(value == null ? null : new Object[] {value.toString()}); } } - throw new IAE("invalid type " + castTo); + throw new IAE("invalid type cannot cast " + type() + " to " + castTo); } @Override @@ -1032,7 +1051,7 @@ public abstract class ExprEval return ExprEval.ofStringArray(value == null ? null : new Object[] {value}); } } - throw new IAE("invalid type " + castTo); + throw new IAE("invalid type cannot cast " + type() + " to " + castTo); } @Override @@ -1221,7 +1240,7 @@ public abstract class ExprEval return ExprEval.ofArray(castTo, cast); } - throw new IAE("invalid type " + castTo); + throw new IAE("invalid type cannot cast " + type() + " to " + castTo); } @Override @@ -1308,7 +1327,11 @@ public abstract class ExprEval if (expressionType.equals(castTo)) { return this; } - throw new IAE("invalid type " + castTo); + // allow cast of unknown complex to some other complex type + if (expressionType.getComplexTypeName() == null) { + return new ComplexExprEval(castTo, value); + } + throw new IAE("invalid type cannot cast " + expressionType + " to " + castTo); } @Override diff --git a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java index f1d3bee5b95..c3322ba19f7 100644 --- a/core/src/test/java/org/apache/druid/math/expr/EvalTest.java +++ b/core/src/test/java/org/apache/druid/math/expr/EvalTest.java @@ -29,12 +29,12 @@ import org.apache.druid.segment.column.TypeStrategiesTest; import org.apache.druid.testing.InitializedNullHandlingTest; import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; @@ -54,9 +54,6 @@ public class EvalTest extends InitializedNullHandlingTest ); } - @Rule - public ExpectedException expectedException = ExpectedException.none(); - private long evalLong(String x, Expr.ObjectBinding bindings) { ExprEval ret = eval(x, bindings); @@ -270,73 +267,91 @@ public class EvalTest extends InitializedNullHandlingTest @Test public void testStringArrayToScalarStringBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type STRING"); - ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.STRING); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.STRING) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to STRING", t.getMessage()); } @Test public void testStringArrayToScalarLongBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type LONG"); - ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.LONG); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.LONG) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to LONG", t.getMessage()); } @Test public void testStringArrayToScalarDoubleBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type DOUBLE"); - ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.DOUBLE); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofStringArray(new String[]{"foo", "bar"}).castTo(ExpressionType.DOUBLE) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to DOUBLE", t.getMessage()); } @Test public void testLongArrayToScalarStringBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type STRING"); - ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.STRING); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.STRING) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to STRING", t.getMessage()); } @Test public void testLongArrayToScalarLongBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type LONG"); - ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.LONG); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.LONG) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to LONG", t.getMessage()); } @Test public void testLongArrayToScalarDoubleBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type DOUBLE"); - ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.DOUBLE); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofLongArray(new Long[]{1L, 2L}).castTo(ExpressionType.DOUBLE) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to DOUBLE", t.getMessage()); } @Test public void testDoubleArrayToScalarStringBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type STRING"); - ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.STRING); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.STRING) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to STRING", t.getMessage()); } @Test public void testDoubleArrayToScalarLongBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type LONG"); - ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.LONG); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.LONG) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to LONG", t.getMessage()); } @Test public void testDoubleArrayToScalarDoubleBadCast() { - expectedException.expect(IAE.class); - expectedException.expectMessage("invalid type DOUBLE"); - ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.DOUBLE); + Throwable t = Assert.assertThrows( + IAE.class, + () -> ExprEval.ofDoubleArray(new Double[]{1.1, 2.2}).castTo(ExpressionType.DOUBLE) + ); + Assert.assertEquals("invalid type cannot cast ARRAY to DOUBLE", t.getMessage()); } @Test @@ -712,6 +727,19 @@ public class EvalTest extends InitializedNullHandlingTest Assert.assertEquals(ExpressionType.STRING, eval.type()); Assert.assertEquals("true", eval.value()); + // strings might also be liars and arrays or lists + eval = ExprEval.ofType(ExpressionType.STRING, new Object[]{"a", "b", "c"}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[]{"a", "b", "c"}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING, new String[]{"a", "b", "c"}); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[]{"a", "b", "c"}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.STRING, Arrays.asList("a", "b", "c")); + Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[]{"a", "b", "c"}, (Object[]) eval.value()); + // longs eval = ExprEval.ofType(ExpressionType.LONG, 1L); Assert.assertEquals(ExpressionType.LONG, eval.type()); @@ -765,40 +793,124 @@ public class EvalTest extends InitializedNullHandlingTest Assert.assertEquals(type, eval.type()); Assert.assertEquals(pair, eval.value()); - // arrays fall back to using 'bestEffortOf', but cast it to the expected output type + // json type isn't defined in druid-core + ExpressionType json = ExpressionType.fromString("COMPLEX"); + eval = ExprEval.ofType(json, ImmutableMap.of("x", 1L, "y", 2L)); + Assert.assertEquals(json, eval.type()); + Assert.assertEquals(ImmutableMap.of("x", 1L, "y", 2L), eval.value()); + + eval = ExprEval.ofType(json, "hello"); + Assert.assertEquals(json, eval.type()); + Assert.assertEquals("hello", eval.value()); + + ExpressionType stringyComplexThing = ExpressionType.fromString("COMPLEX"); + eval = ExprEval.ofType(stringyComplexThing, "notbase64"); + Assert.assertEquals(stringyComplexThing, eval.type()); + Assert.assertEquals("notbase64", eval.value()); + + // arrays + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, ImmutableList.of(1L, 2L, 3L)); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Long[]{1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new long[]{1L, 2L, 3L}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new int[]{1, 2, 3}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, null, 3L}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, NullHandling.defaultLongValue(), 3L}, (Object[]) eval.value()); + + // arrays might have to fall back to using 'bestEffortOf', but will cast it to the expected output type eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {"1", "2", "3"}); Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); - eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1L, 2L, 3L}); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new String[] {"1", "2", "3"}); Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {"1", "2", "wat", "3"}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, NullHandling.defaultLongValue(), 3L}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2.0, 3.0}); Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new double[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2.0, null, 3.0}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, NullHandling.defaultLongValue(), 3L}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new Object[] {1.0, 2L, "3", true, false}); Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L, 1L, 0L}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.LONG_ARRAY, new float[] {1.0f, 2.0f, 3.0f}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + // etc + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Double[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new double[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {"1", "2", "3"}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {"1", "2", "wat", "3"}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, NullHandling.defaultDoubleValue(), 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1L, 2L, 3L}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); - eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2.0, 3.0}); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new long[] {1L, 2L, 3L}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1L, 2L, null, 3L}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, NullHandling.defaultDoubleValue(), 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Object[] {1.0, 2L, "3", true, false}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0, 1.0, 0.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new Float[] {1.0f, 2.0f, 3.0f}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.ofType(ExpressionType.DOUBLE_ARRAY, new float[] {1.0f, 2.0f, 3.0f}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.ofType(ExpressionType.STRING_ARRAY, new Object[] {"1", "2", "3"}); Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {"1", "2", "3"}, (Object[]) eval.value()); @@ -815,20 +927,104 @@ public class EvalTest extends InitializedNullHandlingTest Assert.assertEquals(ExpressionType.STRING_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {"1.0", "2", "3", "true", "false"}, (Object[]) eval.value()); - // json type isn't defined in druid-core - ExpressionType json = ExpressionType.fromString("COMPLEX"); - eval = ExprEval.ofType(json, ImmutableMap.of("x", 1L, "y", 2L)); - Assert.assertEquals(json, eval.type()); - Assert.assertEquals(ImmutableMap.of("x", 1L, "y", 2L), eval.value()); + // nested arrays + try { + ExpressionProcessing.initializeForTests(true); - eval = ExprEval.ofType(json, "hello"); - Assert.assertEquals(json, eval.type()); - Assert.assertEquals("hello", eval.value()); + ExpressionType nestedLongArray = ExpressionTypeFactory.getInstance().ofArray(ExpressionType.LONG_ARRAY); + final Object[] expectedLongArray = new Object[]{ + new Object[] {1L, 2L, 3L}, + new Object[] {5L, NullHandling.defaultLongValue(), 9L}, + null, + new Object[] {2L, 4L, 6L} + }; - ExpressionType stringyComplexThing = ExpressionType.fromString("COMPLEX"); - eval = ExprEval.ofType(stringyComplexThing, "notbase64"); - Assert.assertEquals(stringyComplexThing, eval.type()); - Assert.assertEquals("notbase64", eval.value()); + List longArrayInputs = Arrays.asList( + new Object[]{ + new Object[] {1L, 2L, 3L}, + new Object[] {5L, null, 9L}, + null, + new Object[] {2L, 4L, 6L} + }, + Arrays.asList( + new Object[] {1L, 2L, 3L}, + new Object[] {5L, null, 9L}, + null, + new Object[] {2L, 4L, 6L} + ), + Arrays.asList( + Arrays.asList(1L, 2L, 3L), + Arrays.asList(5L, null, 9L), + null, + Arrays.asList(2L, 4L, 6L) + ), + Arrays.asList( + Arrays.asList(1L, 2L, 3L), + Arrays.asList("5", "hello", "9"), + null, + new Object[]{2.2, 4.4, 6.6} + ) + ); + + for (Object o : longArrayInputs) { + eval = ExprEval.ofType(nestedLongArray, o); + Assert.assertEquals(nestedLongArray, eval.type()); + Object[] val = (Object[]) eval.value(); + Assert.assertEquals(expectedLongArray.length, val.length); + for (int i = 0; i < expectedLongArray.length; i++) { + Assert.assertArrayEquals((Object[]) expectedLongArray[i], (Object[]) val[i]); + } + } + + ExpressionType nestedDoubleArray = ExpressionTypeFactory.getInstance().ofArray(ExpressionType.DOUBLE_ARRAY); + final Object[] expectedDoubleArray = new Object[]{ + new Object[] {1.1, 2.2, 3.3}, + new Object[] {5.5, NullHandling.defaultDoubleValue(), 9.9}, + null, + new Object[] {2.2, 4.4, 6.6} + }; + + List doubleArrayInputs = Arrays.asList( + new Object[]{ + new Object[] {1.1, 2.2, 3.3}, + new Object[] {5.5, null, 9.9}, + null, + new Object[] {2.2, 4.4, 6.6} + }, + new Object[]{ + Arrays.asList(1.1, 2.2, 3.3), + Arrays.asList(5.5, null, 9.9), + null, + Arrays.asList(2.2, 4.4, 6.6) + }, + Arrays.asList( + Arrays.asList(1.1, 2.2, 3.3), + Arrays.asList(5.5, null, 9.9), + null, + Arrays.asList(2.2, 4.4, 6.6) + ), + new Object[]{ + new Object[] {"1.1", "2.2", "3.3"}, + Arrays.asList("5.5", null, "9.9"), + null, + new String[] {"2.2", "4.4", "6.6"} + } + ); + + for (Object o : doubleArrayInputs) { + eval = ExprEval.ofType(nestedDoubleArray, o); + Assert.assertEquals(nestedDoubleArray, eval.type()); + Object[] val = (Object[]) eval.value(); + Assert.assertEquals(expectedLongArray.length, val.length); + for (int i = 0; i < expectedLongArray.length; i++) { + Assert.assertArrayEquals((Object[]) expectedDoubleArray[i], (Object[]) val[i]); + } + } + } + finally { + // reset + ExpressionProcessing.initializeForTests(null); + } } @Test @@ -899,6 +1095,10 @@ public class EvalTest extends InitializedNullHandlingTest Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + eval = ExprEval.bestEffortOf(new Integer[] {1, 2, 3}); + Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); + eval = ExprEval.bestEffortOf(new int[] {1, 2, 3}); Assert.assertEquals(ExpressionType.LONG_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1L, 2L, 3L}, (Object[]) eval.value()); @@ -909,7 +1109,11 @@ public class EvalTest extends InitializedNullHandlingTest eval = ExprEval.bestEffortOf(new Object[] {null, 1.0, 2.0, 3.0}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); - Assert.assertArrayEquals(new Object[] {null, 1.0, 2.0, 3.0}, (Object[]) eval.value()); + Assert.assertArrayEquals(new Object[] {NullHandling.defaultDoubleValue(), 1.0, 2.0, 3.0}, (Object[]) eval.value()); + + eval = ExprEval.bestEffortOf(new Double[] {1.0, 2.0, 3.0}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); eval = ExprEval.bestEffortOf(new double[] {1.0, 2.0, 3.0}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); @@ -919,6 +1123,10 @@ public class EvalTest extends InitializedNullHandlingTest Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.bestEffortOf(new Float[] {1.0f, 2.0f, 3.0f}); + Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); + Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); + eval = ExprEval.bestEffortOf(new float[] {1.0f, 2.0f, 3.0f}); Assert.assertEquals(ExpressionType.DOUBLE_ARRAY, eval.type()); Assert.assertArrayEquals(new Object[] {1.0, 2.0, 3.0}, (Object[]) eval.value()); diff --git a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java index 93b992a4f42..d40ee5dee8e 100644 --- a/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java +++ b/processing/src/main/java/org/apache/druid/segment/column/ObjectStrategyComplexTypeStrategy.java @@ -46,7 +46,7 @@ public class ObjectStrategyComplexTypeStrategy implements TypeStrategy public int estimateSizeBytes(@Nullable T value) { byte[] bytes = objectStrategy.toBytes(value); - return bytes == null ? 0 : bytes.length; + return Integer.BYTES + (bytes == null ? 0 : bytes.length); } @Override @@ -56,7 +56,9 @@ public class ObjectStrategyComplexTypeStrategy implements TypeStrategy ByteBuffer dupe = buffer.duplicate(); dupe.order(buffer.order()); dupe.limit(dupe.position() + complexLength); - return objectStrategy.fromByteBuffer(dupe, complexLength); + T value = objectStrategy.fromByteBuffer(dupe, complexLength); + buffer.position(buffer.position() + complexLength); + return value; } @Override diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java index edbf2b8da5d..c3bf402df9e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/aggregation/builtin/ArraySqlAggregator.java @@ -48,9 +48,7 @@ import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.Expressions; import org.apache.druid.sql.calcite.planner.Calcites; import org.apache.druid.sql.calcite.planner.PlannerContext; -import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException; import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry; -import org.apache.druid.sql.calcite.table.RowSignatures; import javax.annotation.Nullable; import java.util.List; @@ -167,9 +165,6 @@ public class ArraySqlAggregator implements SqlAggregator public RelDataType inferReturnType(SqlOperatorBinding sqlOperatorBinding) { RelDataType type = sqlOperatorBinding.getOperandType(0); - if (type instanceof RowSignatures.ComplexSqlType) { - throw new UnsupportedSQLQueryException("Cannot use ARRAY_AGG on complex inputs %s", type); - } return sqlOperatorBinding.getTypeFactory().createArrayType( type, -1 diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOperatorConversion.java b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOperatorConversion.java index f2592d574d3..b758c2a8972 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOperatorConversion.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/expression/builtin/ArrayOrdinalOperatorConversion.java @@ -62,7 +62,7 @@ public class ArrayOrdinalOperatorConversion extends DirectOperatorConversion { RelDataType type = sqlOperatorBinding.getOperandType(0); if (SqlTypeUtil.isArray(type)) { - type.getComponentType(); + return type.getComponentType(); } return type; } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java index ead3b506e3c..119cae8634e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteArraysQueryTest.java @@ -236,10 +236,10 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest "[\"a\",\"a\",\"b\"]", "[7,0]", "[1.0,0.0]", - "7", - "1.0", - "7", - "1.0" + 7L, + 1.0, + 7L, + 1.0 } ); } else { @@ -264,10 +264,10 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest "[\"a\",\"a\",\"b\"]", "[7,null]", "[1.0,null]", - "7", - "1.0", - "7", - "1.0" + 7L, + 1.0, + 7L, + 1.0 } ); } @@ -312,10 +312,10 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest "array_concat(array(\"d1\"),array(\"d2\"))", ColumnType.DOUBLE_ARRAY ), - expressionVirtualColumn("v12", "array_offset(array(\"l1\"),0)", ColumnType.LONG_ARRAY), - expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.DOUBLE_ARRAY), - expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.LONG_ARRAY), - expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.DOUBLE_ARRAY), + expressionVirtualColumn("v12", "array_offset(array(\"l1\"),0)", ColumnType.LONG), + expressionVirtualColumn("v13", "array_offset(array(\"d1\"),0)", ColumnType.DOUBLE), + expressionVirtualColumn("v14", "array_ordinal(array(\"l1\"),1)", ColumnType.LONG), + expressionVirtualColumn("v15", "array_ordinal(array(\"d1\"),1)", ColumnType.DOUBLE), expressionVirtualColumn("v2", "array(1.9,2.2,4.3)", ColumnType.DOUBLE_ARRAY), expressionVirtualColumn("v3", "array_append(\"dim3\",'foo')", ColumnType.STRING_ARRAY), expressionVirtualColumn("v4", "array_prepend('foo',array(\"dim2\"))", ColumnType.STRING_ARRAY), @@ -355,7 +355,32 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest .context(QUERY_CONTEXT_DEFAULT) .build() ), - expectedResults + expectedResults, + RowSignature.builder() + .add("dim1", ColumnType.STRING) + .add("dim2", ColumnType.STRING) + .add("dim3", ColumnType.STRING) + .add("l1", ColumnType.LONG) + .add("l2", ColumnType.LONG) + .add("d1", ColumnType.DOUBLE) + .add("d2", ColumnType.DOUBLE) + .add("EXPR$7", ColumnType.STRING_ARRAY) + .add("EXPR$8", ColumnType.LONG_ARRAY) + .add("EXPR$9", ColumnType.DOUBLE_ARRAY) + .add("EXPR$10", ColumnType.STRING_ARRAY) + .add("EXPR$11", ColumnType.STRING_ARRAY) + .add("EXPR$12", ColumnType.LONG_ARRAY) + .add("EXPR$13", ColumnType.LONG_ARRAY) + .add("EXPR$14", ColumnType.DOUBLE_ARRAY) + .add("EXPR$15", ColumnType.DOUBLE_ARRAY) + .add("EXPR$16", ColumnType.STRING_ARRAY) + .add("EXPR$17", ColumnType.LONG_ARRAY) + .add("EXPR$18", ColumnType.DOUBLE_ARRAY) + .add("EXPR$19", ColumnType.LONG) + .add("EXPR$20", ColumnType.DOUBLE) + .add("EXPR$21", ColumnType.LONG) + .add("EXPR$22", ColumnType.DOUBLE) + .build() ); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java index f11a7c2717d..36464032e36 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteNestedDataQueryTest.java @@ -33,12 +33,14 @@ import org.apache.druid.data.input.impl.StringDimensionSchema; import org.apache.druid.data.input.impl.TimestampSpec; import org.apache.druid.guice.DruidInjectorBuilder; import org.apache.druid.guice.NestedDataModule; +import org.apache.druid.java.util.common.HumanReadableBytes; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.query.Druids; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; @@ -857,6 +859,85 @@ public class CalciteNestedDataQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testJsonAndArrayAgg() + { + cannotVectorize(); + testQuery( + "SELECT " + + "string, " + + "ARRAY_AGG(nest, 16384), " + + "SUM(cnt) " + + "FROM druid.nested GROUP BY 1", + ImmutableList.of( + GroupByQuery.builder() + .setDataSource(DATA_SOURCE) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + dimensions( + new DefaultDimensionSpec("string", "d0") + ) + ) + .setAggregatorSpecs( + aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("nest"), + "__acc", + "ARRAY>[]", + "ARRAY>[]", + true, + true, + false, + "array_append(\"__acc\", \"nest\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + HumanReadableBytes.valueOf(16384), + queryFramework().macroTable() + ), + new LongSumAggregatorFactory("a1", "cnt") + ) + ) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{ + "aaa", + "[{\"x\":100,\"y\":2.02,\"z\":\"300\",\"mixed\":1,\"mixed2\":\"1\"},{\"x\":100,\"y\":2.02,\"z\":\"400\",\"mixed2\":1.1}]", + 2L + }, + new Object[]{ + "bbb", + "[null]", + 1L + }, + new Object[]{ + "ccc", + "[{\"x\":200,\"y\":3.03,\"z\":\"abcdef\",\"mixed\":1.1,\"mixed2\":1}]", + 1L + }, + new Object[]{ + "ddd", + "[null,null]", + 2L + }, + new Object[]{ + "eee", + "[null]", + 1L + } + ), + RowSignature.builder() + .add("string", ColumnType.STRING) + .add("EXPR$1", ColumnType.ofArray(NestedDataComplexTypeSerde.TYPE)) + .add("EXPR$2", ColumnType.LONG) + .build() + ); + } + @Test public void testGroupByPathSelectorFilterLong() { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 512b7be402e..6da27712633 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -5606,18 +5606,40 @@ public class CalciteQueryTest extends BaseCalciteQueryTest @Test public void testArrayAggQueryOnComplexDatatypes() { + cannotVectorize(); msqCompatible(); - try { - testQuery("SELECT ARRAY_AGG(unique_dim1) FROM druid.foo", ImmutableList.of(), ImmutableList.of()); - Assert.fail("query execution should fail"); - } - catch (SqlPlanningException e) { - Assert.assertTrue( - e.getMessage().contains("Cannot use ARRAY_AGG on complex inputs COMPLEX") - ); - Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), e.getErrorCode()); - Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorClass(), e.getErrorClass()); - } + testQuery( + "SELECT ARRAY_AGG(unique_dim1) FROM druid.foo", + ImmutableList.of( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .aggregators(aggregators( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("unique_dim1"), + "__acc", + "ARRAY>[]", + "ARRAY>[]", + true, + true, + false, + "array_append(\"__acc\", \"unique_dim1\")", + "array_concat(\"__acc\", \"a0\")", + null, + null, + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + queryFramework().macroTable() + ) + )) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ), + ImmutableList.of( + new Object[]{"[\"AQAAAEAAAA==\",\"AQAAAQAAAAHNBA==\",\"AQAAAQAAAAOzAg==\",\"AQAAAQAAAAFREA==\",\"AQAAAQAAAACyEA==\",\"AQAAAQAAAAEkAQ==\"]"} + ) + ); } @Test