Fix up value types when creating range filters. (#15778)

Fixes a bug introduced in #15609, where queries involving filters on
TIME_FLOOR could encounter ClassCastException when comparing RangeValue
in CombineAndSimplifyBounds.

Prior to #15609, CombineAndSimplifyBounds would remove, rebuild, and
re-add all numeric range filters as part of consolidating numeric range
filters for the same column under the least restrictive type. #15609
included a change to only rebuild numeric range filters when a consolidation
opportunity actually arises. The bug was introduced because the unconditional
rebuild, as a side effect, masked the fact that in some cases range filters
would be created with string match values and a LONG match value type.

This patch changes the fixup to happen at the time the range filter is
initially created, rather than in CombineAndSimplifyBounds.
This commit is contained in:
Gian Merlino 2024-01-29 13:30:47 -08:00 committed by GitHub
parent 54d0e482dc
commit 38a1e827ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 106 additions and 17 deletions

View File

@ -965,17 +965,17 @@ public class Expressions
? new NotDimFilter(Ranges.interval(rangeRefKey, interval))
: Filtration.matchEverything();
case GREATER_THAN:
return Ranges.greaterThanOrEqualTo(rangeRefKey, String.valueOf(interval.getEndMillis()));
return Ranges.greaterThanOrEqualTo(rangeRefKey, interval.getEndMillis());
case GREATER_THAN_OR_EQUAL:
return isAligned
? Ranges.greaterThanOrEqualTo(rangeRefKey, String.valueOf(interval.getStartMillis()))
: Ranges.greaterThanOrEqualTo(rangeRefKey, String.valueOf(interval.getEndMillis()));
? Ranges.greaterThanOrEqualTo(rangeRefKey, interval.getStartMillis())
: Ranges.greaterThanOrEqualTo(rangeRefKey, interval.getEndMillis());
case LESS_THAN:
return isAligned
? Ranges.lessThan(rangeRefKey, String.valueOf(interval.getStartMillis()))
: Ranges.lessThan(rangeRefKey, String.valueOf(interval.getEndMillis()));
? Ranges.lessThan(rangeRefKey, interval.getStartMillis())
: Ranges.lessThan(rangeRefKey, interval.getEndMillis());
case LESS_THAN_OR_EQUAL:
return Ranges.lessThan(rangeRefKey, String.valueOf(interval.getEndMillis()));
return Ranges.lessThan(rangeRefKey, interval.getEndMillis());
default:
throw new IllegalStateException("Shouldn't have got here");
}

View File

@ -49,11 +49,6 @@ public class RangeValue implements Comparable<RangeValue>
return value;
}
public ColumnType getMatchValueType()
{
return matchValueType;
}
@Override
public int compareTo(RangeValue o)
{

View File

@ -28,6 +28,7 @@ import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.filter.RangeFilter;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.ValueType;
import org.joda.time.Interval;
import javax.annotation.Nullable;
@ -39,6 +40,7 @@ public class Ranges
* Negates single-ended Bound filters.
*
* @param range filter
*
* @return negated filter, or null if this range is double-ended.
*/
@Nullable
@ -133,11 +135,12 @@ public class Ranges
public static RangeFilter equalTo(final RangeRefKey rangeRefKey, final Object value)
{
final Object castValue = castVal(rangeRefKey, value);
return new RangeFilter(
rangeRefKey.getColumn(),
rangeRefKey.getMatchValueType(),
value,
value,
castValue,
castValue,
false,
false,
null
@ -149,7 +152,7 @@ public class Ranges
return new RangeFilter(
rangeRefKey.getColumn(),
rangeRefKey.getMatchValueType(),
value,
castVal(rangeRefKey, value),
null,
true,
false,
@ -162,7 +165,7 @@ public class Ranges
return new RangeFilter(
rangeRefKey.getColumn(),
rangeRefKey.getMatchValueType(),
value,
castVal(rangeRefKey, value),
null,
false,
false,
@ -176,7 +179,7 @@ public class Ranges
rangeRefKey.getColumn(),
rangeRefKey.getMatchValueType(),
null,
value,
castVal(rangeRefKey, value),
false,
true,
null
@ -189,7 +192,7 @@ public class Ranges
rangeRefKey.getColumn(),
rangeRefKey.getMatchValueType(),
null,
value,
castVal(rangeRefKey, value),
false,
false,
null
@ -213,4 +216,30 @@ public class Ranges
null
);
}
/**
* Casts a primitive value such that it matches the {@link RangeRefKey#getMatchValueType()} of a provided key.
* Leaves nonprimitive values as-is.
*/
private static Object castVal(final RangeRefKey rangeRefKey, final Object value)
{
if (value instanceof String || value instanceof Number || value == null) {
final ColumnType columnType = rangeRefKey.getMatchValueType();
if (columnType.is(ValueType.STRING) && (value instanceof String || value == null)) {
// Short-circuit to save creation of ExprEval.
return value;
} else if (columnType.is(ValueType.DOUBLE) && value instanceof Double) {
// Short-circuit to save creation of ExprEval.
return value;
} else if (columnType.is(ValueType.LONG) && value instanceof Long) {
// Short-circuit to save creation of ExprEval.
return value;
} else {
final ExpressionType expressionType = ExpressionType.fromColumnType(columnType);
return ExprEval.ofType(expressionType, value).valueOrDefault();
}
} else {
return value;
}
}
}

View File

@ -5995,6 +5995,50 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testCountStarWithFloorTimeFilter()
{
testQuery(
"SELECT COUNT(*) FROM druid.foo "
+ "WHERE FLOOR(__time TO DAY) >= TIMESTAMP '2000-01-01 00:00:00' AND "
+ "FLOOR(__time TO DAY) < TIMESTAMP '2001-01-01 00:00:00'",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01")))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L}
)
);
}
@Test
public void testCountStarWithMisalignedFloorTimeFilter()
{
testQuery(
"SELECT COUNT(*) FROM druid.foo "
+ "WHERE FLOOR(__time TO DAY) >= TIMESTAMP '2000-01-01 00:00:01' AND "
+ "FLOOR(__time TO DAY) < TIMESTAMP '2001-01-01 00:00:01'",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.of("2000-01-02/2001-01-02")))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L}
)
);
}
@Test
public void testCountStarWithTimeInIntervalFilter()
{
@ -6114,6 +6158,27 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testCountStarWithBetweenFloorTimeFilter()
{
testQuery(
"SELECT COUNT(*) FROM druid.foo "
+ "WHERE FLOOR(__time TO DAY) BETWEEN TIMESTAMP '2000-01-01 00:00:00' AND TIMESTAMP '2000-12-31 00:00:00'",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Intervals.of("2000-01-01/2001-01-01")))
.granularity(Granularities.ALL)
.aggregators(aggregators(new CountAggregatorFactory("a0")))
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(
new Object[]{3L}
)
);
}
@Test
public void testCountStarWithBetweenTimeFilterUsingMillisecondsInStringLiterals()
{