From d7c9c2f3671d1e9cefff463eb5cb557e9e882f38 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Sun, 25 Jun 2023 09:35:18 -0700 Subject: [PATCH] SqlResults: Coerce arrays to lists for VARCHAR. (#14260) * SqlResults: Coerce arrays to lists for VARCHAR. Useful for STRING_TO_MV, which returns VARCHAR at the SQL layer and an ExprEval with String[] at the native layer. * Fix style. * Improve test coverage. * Remove unnecessary throws. --- .../druid/sql/calcite/run/SqlResults.java | 100 +++++---- .../CalciteMultiValueStringQueryTest.java | 196 ++++++++++++++++- .../druid/sql/calcite/run/SqlResultsTest.java | 206 ++++++++++++++++-- 3 files changed, 430 insertions(+), 72 deletions(-) diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlResults.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlResults.java index d48eaa8742a..ab400b97a3e 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlResults.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlResults.java @@ -39,7 +39,7 @@ import org.apache.druid.sql.calcite.planner.PlannerContext; import org.joda.time.DateTime; import org.joda.time.DateTimeZone; -import java.io.IOException; +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -70,21 +70,23 @@ public class SqlResults coercedValue = String.valueOf(value); } else if (value instanceof Boolean) { coercedValue = String.valueOf(value); - } else if (value instanceof Collection) { - // Iterate through the collection, coercing each value. Useful for handling selects of multi-value dimensions. - final List valueStrings = - ((Collection) value).stream() - .map(v -> (String) coerce(jsonMapper, context, v, sqlTypeName)) - .collect(Collectors.toList()); - - try { - coercedValue = jsonMapper.writeValueAsString(valueStrings); - } - catch (IOException e) { - throw new RuntimeException(e); - } } else { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + final Object maybeList = maybeCoerceArrayToList(value, false); + + // Check if "maybeList" was originally a Collection of some kind, or was able to be coerced to one. + // Then Iterate through the collection, coercing each value. Useful for handling multi-value dimensions. + if (maybeList instanceof Collection) { + final List valueStrings = + ((Collection) maybeList) + .stream() + .map(v -> (String) coerce(jsonMapper, context, v, sqlTypeName)) + .collect(Collectors.toList()); + + // Must stringify since the caller is expecting CHAR_TYPES. + coercedValue = coerceUsingObjectMapper(jsonMapper, valueStrings, sqlTypeName); + } else { + throw cannotCoerce(value, sqlTypeName); + } } } else if (value == null) { coercedValue = null; @@ -93,12 +95,14 @@ public class SqlResults } else if (sqlTypeName == SqlTypeName.TIMESTAMP) { return Calcites.jodaToCalciteTimestamp(coerceDateTime(value, sqlTypeName), context.getTimeZone()); } else if (sqlTypeName == SqlTypeName.BOOLEAN) { - if (value instanceof String) { + if (value instanceof Boolean) { + coercedValue = value; + } else if (value instanceof String) { coercedValue = Evals.asBoolean(((String) value)); } else if (value instanceof Number) { coercedValue = Evals.asBoolean(((Number) value).longValue()); } else { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } else if (sqlTypeName == SqlTypeName.INTEGER) { if (value instanceof String) { @@ -106,38 +110,33 @@ public class SqlResults } else if (value instanceof Number) { coercedValue = ((Number) value).intValue(); } else { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } else if (sqlTypeName == SqlTypeName.BIGINT) { try { coercedValue = DimensionHandlerUtils.convertObjectToLong(value); } catch (Exception e) { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } else if (sqlTypeName == SqlTypeName.FLOAT) { try { coercedValue = DimensionHandlerUtils.convertObjectToFloat(value); } catch (Exception e) { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } else if (SqlTypeName.FRACTIONAL_TYPES.contains(sqlTypeName)) { try { coercedValue = DimensionHandlerUtils.convertObjectToDouble(value); } catch (Exception e) { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } else if (sqlTypeName == SqlTypeName.OTHER) { // Complex type, try to serialize if we should, else print class name if (context.isSerializeComplexValues()) { - try { - coercedValue = jsonMapper.writeValueAsString(value); - } - catch (JsonProcessingException jex) { - throw new ISE(jex, "Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); - } + coercedValue = coerceUsingObjectMapper(jsonMapper, value, sqlTypeName); } else { coercedValue = value.getClass().getName(); } @@ -148,12 +147,7 @@ public class SqlResults } else if (value instanceof NlsString) { coercedValue = ((NlsString) value).getValue(); } else { - try { - coercedValue = jsonMapper.writeValueAsString(value); - } - catch (IOException e) { - throw new RuntimeException(e); - } + coercedValue = coerceUsingObjectMapper(jsonMapper, value, sqlTypeName); } } else { // the protobuf jdbc handler prefers lists (it actually can't handle java arrays as sql arrays, only java lists) @@ -161,18 +155,22 @@ public class SqlResults // here if needed coercedValue = maybeCoerceArrayToList(value, true); if (coercedValue == null) { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } } } else { - throw new ISE("Cannot coerce [%s] to %s", value.getClass().getName(), sqlTypeName); + throw cannotCoerce(value, sqlTypeName); } return coercedValue; } - + /** + * Attempt to coerce a value to {@link List}. If it cannot be coerced, either return the original value (if mustCoerce + * is false) or return null (if mustCoerce is true). + */ @VisibleForTesting + @Nullable static Object maybeCoerceArrayToList(Object value, boolean mustCoerce) { if (value instanceof List) { @@ -222,11 +220,39 @@ public class SqlResults } else if (value instanceof DateTime) { dateTime = (DateTime) value; } else { - throw new ISE("Cannot coerce[%s] to %s", value.getClass().getName(), sqlType); + throw cannotCoerce(value, sqlType); } return dateTime; } + private static String coerceUsingObjectMapper( + final ObjectMapper jsonMapper, + final Object value, + final SqlTypeName sqlTypeName + ) + { + try { + return jsonMapper.writeValueAsString(value); + } + catch (JsonProcessingException e) { + throw cannotCoerce(e, value, sqlTypeName); + } + } + + private static IllegalStateException cannotCoerce( + final Throwable t, + final Object value, + final SqlTypeName sqlTypeName + ) + { + return new ISE(t, "Cannot coerce [%s] to [%s]", value == null ? "null" : value.getClass().getName(), sqlTypeName); + } + + private static IllegalStateException cannotCoerce(final Object value, final SqlTypeName sqlTypeName) + { + return cannotCoerce(null, value, sqlTypeName); + } + /** * Context for {@link #coerce(ObjectMapper, Context, Object, SqlTypeName)} */ diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java index d016f179532..f17ef78c175 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteMultiValueStringQueryTest.java @@ -27,8 +27,12 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.math.expr.ExpressionProcessing; import org.apache.druid.query.Druids; +import org.apache.druid.query.aggregation.CountAggregatorFactory; +import org.apache.druid.query.aggregation.ExpressionLambdaAggregatorFactory; +import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.dimension.DefaultDimensionSpec; +import org.apache.druid.query.expression.TestExprMacroTable; import org.apache.druid.query.filter.AndDimFilter; import org.apache.druid.query.filter.ExpressionDimFilter; import org.apache.druid.query.filter.InDimFilter; @@ -1010,7 +1014,7 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ); } testQuery( - "SELECT STRING_TO_MV(CONCAT(MV_TO_STRING(dim3, ','), ',d'), ','), SUM(cnt) FROM druid.numfoo WHERE MV_LENGTH(dim3) > 0 GROUP BY 1 ORDER BY 2 DESC", + "SELECT STRING_TO_MV(CONCAT(MV_TO_STRING(dim3, ','), ',d'), ','), SUM(cnt) FROM druid.numfoo WHERE MV_LENGTH(dim3) > 0 GROUP BY 1 ORDER BY 2 DESC, 1", ImmutableList.of( GroupByQuery.builder() .setDataSource(CalciteTests.DATASOURCE3) @@ -1025,18 +1029,21 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ) ) .setDimFilter(bound("v0", "0", null, true, false, null, StringComparators.NUMERIC)) - .setDimensions( - dimensions( - new DefaultDimensionSpec("v1", "_d0", ColumnType.STRING) - ) - ) + .setDimensions(dimensions(new DefaultDimensionSpec("v1", "_d0", ColumnType.STRING))) .setAggregatorSpecs(aggregators(new LongSumAggregatorFactory("a0", "cnt"))) .setLimitSpec(new DefaultLimitSpec( - ImmutableList.of(new OrderByColumnSpec( - "a0", - OrderByColumnSpec.Direction.DESCENDING, - StringComparators.NUMERIC - )), + ImmutableList.of( + new OrderByColumnSpec( + "a0", + OrderByColumnSpec.Direction.DESCENDING, + StringComparators.NUMERIC + ), + new OrderByColumnSpec( + "_d0", + OrderByColumnSpec.Direction.ASCENDING, + StringComparators.LEXICOGRAPHIC + ) + ), Integer.MAX_VALUE )) .setContext(QUERY_CONTEXT_DEFAULT) @@ -1046,6 +1053,173 @@ public class CalciteMultiValueStringQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testSelectAndFilterByStringToMV() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testBuilder() + .sql("SELECT STRING_TO_MV(CONCAT(MV_TO_STRING(dim3, ','), ',d'), ',') FROM druid.numfoo " + + "WHERE MV_CONTAINS(STRING_TO_MV(CONCAT(MV_TO_STRING(dim3, ','), ',d'), ','), 'd')") + .expectedQuery( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .virtualColumns( + expressionVirtualColumn( + "v0", + "string_to_array(concat(array_to_string(\"dim3\",','),',d'),',')", + ColumnType.STRING + ) + ) + .filters(expressionFilter( + "array_contains(string_to_array(concat(array_to_string(\"dim3\",','),',d'),','),'d')")) + .columns("v0") + .context(QUERY_CONTEXT_DEFAULT) + .build() + ) + .expectedResults( + NullHandling.sqlCompatible() ? + ImmutableList.of( + new Object[]{"[\"a\",\"b\",\"d\"]"}, + new Object[]{"[\"b\",\"c\",\"d\"]"}, + new Object[]{"[\"d\",\"d\"]"}, + new Object[]{"[\"\",\"d\"]"} + ) : ImmutableList.of( + new Object[]{"[\"a\",\"b\",\"d\"]"}, + new Object[]{"[\"b\",\"c\",\"d\"]"}, + new Object[]{"[\"d\",\"d\"]"}, + new Object[]{"[\"\",\"d\"]"}, + new Object[]{"[\"\",\"d\"]"}, + new Object[]{"[\"\",\"d\"]"} + ) + ) + .run(); + } + + @Test + public void testStringToMVOfConstant() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testBuilder() + .sql("SELECT m1, STRING_TO_MV('a,b', ',') AS mv FROM druid.numfoo GROUP BY 1") + .expectedQuery( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setPostAggregatorSpecs(ImmutableList.of(expressionPostAgg("p0", "string_to_array('a,b',',')"))) + .setDimensions(dimensions(new DefaultDimensionSpec("m1", "_d0", ColumnType.FLOAT))) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ) + .expectedResults( + ImmutableList.of( + new Object[]{1.0f, "[\"a\",\"b\"]"}, + new Object[]{2.0f, "[\"a\",\"b\"]"}, + new Object[]{3.0f, "[\"a\",\"b\"]"}, + new Object[]{4.0f, "[\"a\",\"b\"]"}, + new Object[]{5.0f, "[\"a\",\"b\"]"}, + new Object[]{6.0f, "[\"a\",\"b\"]"} + ) + ) + .run(); + } + + @Test + public void testStringToMVOfConstantGroupedBy() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + testBuilder() + .sql("SELECT m1, STRING_TO_MV('a,b', ',') AS mv FROM druid.numfoo GROUP BY 1, 2") + .expectedQuery( + GroupByQuery.builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setVirtualColumns( + expressionVirtualColumn("v0", "string_to_array('a,b',',')", ColumnType.STRING) + ) + .setDimensions(dimensions( + new DefaultDimensionSpec("m1", "_d0", ColumnType.FLOAT), + new DefaultDimensionSpec("v0", "_d1", ColumnType.STRING) + )) + .setContext(QUERY_CONTEXT_DEFAULT) + .build() + ) + .expectedResults( + ImmutableList.of( + new Object[]{1.0f, "a"}, + new Object[]{1.0f, "b"}, + new Object[]{2.0f, "a"}, + new Object[]{2.0f, "b"}, + new Object[]{3.0f, "a"}, + new Object[]{3.0f, "b"}, + new Object[]{4.0f, "a"}, + new Object[]{4.0f, "b"}, + new Object[]{5.0f, "a"}, + new Object[]{5.0f, "b"}, + new Object[]{6.0f, "a"}, + new Object[]{6.0f, "b"} + ) + ) + .run(); + } + + @Test + public void testStringToMVOfStringAgg() + { + // Cannot vectorize due to usage of expressions. + cannotVectorize(); + + final String expectedResult; + + if (NullHandling.sqlCompatible()) { + expectedResult = "[\"\",\"10.1\",\"2\",\"1\",\"def\",\"abc\"]"; + } else { + expectedResult = "[\"10.1\",\"2\",\"1\",\"def\",\"abc\"]"; + } + + testBuilder() + .sql("SELECT STRING_TO_MV(STRING_AGG(dim1, ','), ',') AS mv, COUNT(*) cnt FROM druid.numfoo") + .expectedQuery( + Druids.newTimeseriesQueryBuilder() + .dataSource(CalciteTests.DATASOURCE3) + .intervals(querySegmentSpec(Filtration.eternity())) + .aggregators( + new FilteredAggregatorFactory( + new ExpressionLambdaAggregatorFactory( + "a0", + ImmutableSet.of("dim1"), + "__acc", + "[]", + "[]", + true, + false, + false, + "array_append(\"__acc\", \"dim1\")", + "array_concat(\"__acc\", \"a0\")", + null, + "if(array_length(o) == 0, null, array_to_string(o, ','))", + ExpressionLambdaAggregatorFactory.DEFAULT_MAX_SIZE_BYTES, + TestExprMacroTable.INSTANCE + ), + not(selector("dim1", null, null)) + ), + new CountAggregatorFactory("a1") + ) + .postAggregators(expressionPostAgg("p0", "string_to_array(\"a0\",',')")) + .context(QUERY_CONTEXT_DEFAULT) + .build() + ) + .expectedResults(ImmutableList.of(new Object[]{expectedResult, 6L})) + .run(); + } @Test public void testMultiValueListFilter() diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/run/SqlResultsTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/run/SqlResultsTest.java index d19e218075a..954170fffeb 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/run/SqlResultsTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/run/SqlResultsTest.java @@ -19,16 +19,39 @@ package org.apache.druid.sql.calcite.run; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSortedSet; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.common.config.NullHandling; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.data.ComparableList; import org.apache.druid.segment.data.ComparableStringArray; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.joda.time.DateTimeZone; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; import java.util.Arrays; +import java.util.Collections; import java.util.List; -public class SqlResultsTest +public class SqlResultsTest extends InitializedNullHandlingTest { + private static final SqlResults.Context DEFAULT_CONTEXT = new SqlResults.Context(DateTimeZone.UTC, true, false); + + private ObjectMapper jsonMapper; + + @Before + public void setUp() + { + jsonMapper = TestHelper.JSON_MAPPER; + } @Test public void testCoerceStringArrays() @@ -38,10 +61,10 @@ public class SqlResultsTest final ComparableStringArray comparableStringArray = ComparableStringArray.of(new String[]{"x", "y", "z", null}); final String[] stringArray2 = new String[]{"x", "y", "z", null}; - assertCoerced(stringList, stringList, true); - assertCoerced(stringList, stringArray, true); - assertCoerced(stringList, stringArray2, true); - assertCoerced(stringList, comparableStringArray, true); + assertCoerceArrayToList(stringList, stringList); + assertCoerceArrayToList(stringList, stringArray); + assertCoerceArrayToList(stringList, stringArray2); + assertCoerceArrayToList(stringList, comparableStringArray); } @Test @@ -53,11 +76,11 @@ public class SqlResultsTest final List list = Arrays.asList(1L, 2L, 3L); final long[] array = new long[]{1L, 2L, 3L}; - assertCoerced(listWithNull, listWithNull, true); - assertCoerced(listWithNull, arrayWithNull, true); - assertCoerced(listWithNull, comparableList, true); - assertCoerced(list, list, true); - assertCoerced(list, array, true); + assertCoerceArrayToList(listWithNull, listWithNull); + assertCoerceArrayToList(listWithNull, arrayWithNull); + assertCoerceArrayToList(listWithNull, comparableList); + assertCoerceArrayToList(list, list); + assertCoerceArrayToList(list, array); } @Test @@ -69,11 +92,11 @@ public class SqlResultsTest final List list = Arrays.asList(1.1, 2.2, 3.3); final double[] array = new double[]{1.1, 2.2, 3.3}; - assertCoerced(listWithNull, listWithNull, true); - assertCoerced(listWithNull, arrayWithNull, true); - assertCoerced(listWithNull, comparableList, true); - assertCoerced(list, list, true); - assertCoerced(list, array, true); + assertCoerceArrayToList(listWithNull, listWithNull); + assertCoerceArrayToList(listWithNull, arrayWithNull); + assertCoerceArrayToList(listWithNull, comparableList); + assertCoerceArrayToList(list, list); + assertCoerceArrayToList(list, array); } @Test @@ -85,11 +108,11 @@ public class SqlResultsTest final List list = Arrays.asList(1.1f, 2.2f, 3.3f); final float[] array = new float[]{1.1f, 2.2f, 3.3f}; - assertCoerced(listWithNull, listWithNull, true); - assertCoerced(listWithNull, arrayWithNull, true); - assertCoerced(listWithNull, comparableList, true); - assertCoerced(list, list, true); - assertCoerced(list, array, true); + assertCoerceArrayToList(listWithNull, listWithNull); + assertCoerceArrayToList(listWithNull, arrayWithNull); + assertCoerceArrayToList(listWithNull, comparableList); + assertCoerceArrayToList(list, list); + assertCoerceArrayToList(list, array); } @Test @@ -98,8 +121,117 @@ public class SqlResultsTest List nestedList = Arrays.asList(Arrays.asList(1L, 2L, 3L), Arrays.asList(4L, 5L, 6L)); Object[] nestedArray = new Object[]{new Object[]{1L, 2L, 3L}, new Object[]{4L, 5L, 6L}}; - assertCoerced(nestedList, nestedList, true); - assertCoerced(nestedList, nestedArray, true); + assertCoerceArrayToList(nestedList, nestedList); + assertCoerceArrayToList(nestedList, nestedArray); + } + + @Test + public void testCoerceBoolean() + { + assertCoerce(false, false, SqlTypeName.BOOLEAN); + assertCoerce(false, "xyz", SqlTypeName.BOOLEAN); + assertCoerce(false, 0, SqlTypeName.BOOLEAN); + assertCoerce(false, "false", SqlTypeName.BOOLEAN); + assertCoerce(true, true, SqlTypeName.BOOLEAN); + assertCoerce(true, "true", SqlTypeName.BOOLEAN); + assertCoerce(true, 1, SqlTypeName.BOOLEAN); + assertCoerce(true, 1.0, SqlTypeName.BOOLEAN); + assertCoerce(null, null, SqlTypeName.BOOLEAN); + + assertCannotCoerce(Collections.emptyList(), SqlTypeName.BOOLEAN); + } + + @Test + public void testCoerceInteger() + { + assertCoerce(0, 0, SqlTypeName.INTEGER); + assertCoerce(1, 1L, SqlTypeName.INTEGER); + assertCoerce(1, 1f, SqlTypeName.INTEGER); + assertCoerce(1, "1", SqlTypeName.INTEGER); + assertCoerce(null, "1.1", SqlTypeName.INTEGER); + assertCoerce(null, "xyz", SqlTypeName.INTEGER); + assertCoerce(null, null, SqlTypeName.INTEGER); + + assertCannotCoerce(Collections.emptyList(), SqlTypeName.INTEGER); + assertCannotCoerce(false, SqlTypeName.INTEGER); + } + + @Test + public void testCoerceBigint() + { + assertCoerce(0L, 0, SqlTypeName.BIGINT); + assertCoerce(1L, 1L, SqlTypeName.BIGINT); + assertCoerce(1L, 1f, SqlTypeName.BIGINT); + assertCoerce(null, "1.1", SqlTypeName.BIGINT); + assertCoerce(null, "xyz", SqlTypeName.BIGINT); + assertCoerce(null, null, SqlTypeName.BIGINT); + + // Inconsistency with FLOAT, INTEGER, DOUBLE. + assertCoerce(0L, false, SqlTypeName.BIGINT); + assertCoerce(1L, true, SqlTypeName.BIGINT); + + assertCannotCoerce(Collections.emptyList(), SqlTypeName.BIGINT); + } + + @Test + public void testCoerceFloat() + { + assertCoerce(0f, 0, SqlTypeName.FLOAT); + assertCoerce(1f, 1L, SqlTypeName.FLOAT); + assertCoerce(1f, 1f, SqlTypeName.FLOAT); + assertCoerce(1.1f, "1.1", SqlTypeName.FLOAT); + assertCoerce(null, "xyz", SqlTypeName.FLOAT); + assertCoerce(null, null, SqlTypeName.FLOAT); + + assertCannotCoerce(Collections.emptyList(), SqlTypeName.FLOAT); + assertCannotCoerce(false, SqlTypeName.FLOAT); + } + + @Test + public void testCoerceDouble() + { + assertCoerce(0d, 0, SqlTypeName.DOUBLE); + assertCoerce(1d, 1L, SqlTypeName.DOUBLE); + assertCoerce(1d, 1f, SqlTypeName.DOUBLE); + assertCoerce(1.1d, "1.1", SqlTypeName.DOUBLE); + assertCoerce(null, "xyz", SqlTypeName.DOUBLE); + assertCoerce(null, null, SqlTypeName.DOUBLE); + + assertCannotCoerce(Collections.emptyList(), SqlTypeName.DOUBLE); + assertCannotCoerce(false, SqlTypeName.DOUBLE); + } + + @Test + public void testCoerceString() + { + assertCoerce(NullHandling.defaultStringValue(), null, SqlTypeName.VARCHAR); + assertCoerce("1", 1, SqlTypeName.VARCHAR); + assertCoerce("true", true, SqlTypeName.VARCHAR); + assertCoerce("abc", "abc", SqlTypeName.VARCHAR); + + assertCoerce("[\"abc\",\"def\"]", ImmutableList.of("abc", "def"), SqlTypeName.VARCHAR); + assertCoerce("[\"abc\",\"def\"]", ImmutableSortedSet.of("abc", "def"), SqlTypeName.VARCHAR); + assertCoerce("[\"abc\",\"def\"]", new String[]{"abc", "def"}, SqlTypeName.VARCHAR); + assertCoerce("[\"abc\",\"def\"]", new Object[]{"abc", "def"}, SqlTypeName.VARCHAR); + + assertCoerce("[\"abc\"]", ImmutableList.of("abc"), SqlTypeName.VARCHAR); + assertCoerce("[\"abc\"]", ImmutableSortedSet.of("abc"), SqlTypeName.VARCHAR); + assertCoerce("[\"abc\"]", new String[]{"abc"}, SqlTypeName.VARCHAR); + assertCoerce("[\"abc\"]", new Object[]{"abc"}, SqlTypeName.VARCHAR); + + assertCannotCoerce(new Object(), SqlTypeName.VARCHAR); + } + + @Test + public void testCoerceArrayFails() + { + assertCannotCoerce("xyz", SqlTypeName.ARRAY); + } + + @Test + public void testCoerceUnsupportedType() + { + assertCannotCoerce("xyz", SqlTypeName.VARBINARY); } @Test @@ -108,9 +240,35 @@ public class SqlResultsTest Assert.assertNull(SqlResults.maybeCoerceArrayToList("hello", true)); } - private static void assertCoerced(Object expected, Object toCoerce, boolean mustCoerce) + @Test + public void testMayNotCoerce() { - Object coerced = SqlResults.maybeCoerceArrayToList(toCoerce, mustCoerce); + Assert.assertEquals("hello", SqlResults.maybeCoerceArrayToList("hello", false)); + } + + private void assertCoerce(Object expected, Object toCoerce, SqlTypeName typeName) + { + Assert.assertEquals( + StringUtils.format("Coerce [%s] to [%s]", toCoerce, typeName), + expected, + SqlResults.coerce(jsonMapper, DEFAULT_CONTEXT, toCoerce, typeName) + ); + } + + private void assertCannotCoerce(Object toCoerce, SqlTypeName typeName) + { + final IllegalStateException e = Assert.assertThrows( + StringUtils.format("Coerce [%s] to [%s]", toCoerce, typeName), + IllegalStateException.class, + () -> SqlResults.coerce(jsonMapper, DEFAULT_CONTEXT, toCoerce, typeName) + ); + + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString("Cannot coerce"))); + } + + private static void assertCoerceArrayToList(Object expected, Object toCoerce) + { + Object coerced = SqlResults.maybeCoerceArrayToList(toCoerce, true); Assert.assertEquals(expected, coerced); } }