From 09e0eefdc39e5e5829546adf4c349005ade3efbd Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Mon, 8 Jul 2024 10:58:05 -0700 Subject: [PATCH] modify equality and typed in filter behavior for numeric match values on string columns (#16593) * fix equality and typed in filter behavior for numeric match values on string columns changes: * EqualityFilter and TypedInFilter numeric match values against string columns will now cast strings to numeric values instead of converting the numeric values directly to string for pure string equality, which is consistent with the casts that are eaten in the SQL layer, as well as classic Druid behavior * added tests to cover numeric equality matching. Double match values in particular would fail to match the string values since `1.0` would become `'1.0'` which does not match `'1'`. --- .../druid/query/filter/EqualityFilter.java | 85 +++++++++++++++---- .../druid/query/filter/RangeFilter.java | 3 +- .../druid/query/filter/TypedInFilter.java | 66 ++++++++++++-- .../index/IndexedUtf8ValueIndexes.java | 2 +- .../segment/filter/EqualityFilterTests.java | 6 ++ .../druid/segment/filter/InFilterTests.java | 23 +++++ 6 files changed, 156 insertions(+), 29 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/filter/EqualityFilter.java b/processing/src/main/java/org/apache/druid/query/filter/EqualityFilter.java index 06506c64d1a..f7b2dd1cdb9 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/EqualityFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/EqualityFilter.java @@ -31,6 +31,7 @@ import com.google.common.collect.TreeRangeSet; import org.apache.druid.error.InvalidInput; import org.apache.druid.java.util.common.IAE; import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExprType; import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.filter.vector.VectorValueMatcher; @@ -43,12
+44,14 @@ import org.apache.druid.segment.ColumnInspector; import org.apache.druid.segment.ColumnProcessorFactory; import org.apache.druid.segment.ColumnProcessors; import org.apache.druid.segment.ColumnSelectorFactory; +import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.DimensionSelector; import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnIndexSupplier; import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.TypeSignature; import org.apache.druid.segment.column.TypeStrategy; +import org.apache.druid.segment.column.Types; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.PredicateValueMatcherFactory; @@ -244,8 +247,9 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt public VectorValueMatcher makeVectorMatcher(VectorColumnSelectorFactory factory) { final ColumnCapabilities capabilities = factory.getColumnCapabilities(column); - - if (matchValueType.isPrimitive() && (capabilities == null || capabilities.isPrimitive())) { + final boolean primitiveMatch = matchValueType.isPrimitive() && (capabilities == null || capabilities.isPrimitive()); + if (primitiveMatch && useSimpleEquality(capabilities, matchValueType)) { + // if possible, use simplified value matcher instead of predicate return ColumnProcessors.makeVectorProcessor( column, VectorValueMatcherColumnProcessorFactory.instance(), @@ -298,6 +302,20 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt ); } + /** + * Can the match value type be cast directly to column type for equality comparison? For non-numeric match types, we + * just use exact string equality regardless of the column type. For numeric match value types against string columns, + * we instead cast the string to the match value type number for matching equality. 
+ */ + public static boolean useSimpleEquality(TypeSignature columnType, ColumnType matchValueType) + { + if (Types.is(columnType, ValueType.STRING)) { + return !matchValueType.isNumeric(); + } + return true; + } + + @Nullable public static BitmapColumnIndex getEqualityIndex( String column, ExprEval matchValueEval, @@ -311,20 +329,22 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt return new AllUnknownBitmapColumnIndex(selector); } - final ValueIndexes valueIndexes = indexSupplier.as(ValueIndexes.class); - if (valueIndexes != null) { - // matchValueEval.value() cannot be null here due to check in the constructor - //noinspection DataFlowIssue - return valueIndexes.forValue(matchValueEval.value(), matchValueType); - } + if (useSimpleEquality(selector.getColumnCapabilities(column), matchValueType)) { + final ValueIndexes valueIndexes = indexSupplier.as(ValueIndexes.class); + if (valueIndexes != null) { + // matchValueEval.value() cannot be null here due to check in the constructor + //noinspection DataFlowIssue + return valueIndexes.forValue(matchValueEval.value(), matchValueType); + } + if (matchValueType.isPrimitive()) { + final StringValueSetIndexes stringValueSetIndexes = indexSupplier.as(StringValueSetIndexes.class); + if (stringValueSetIndexes != null) { - if (matchValueType.isPrimitive()) { - final StringValueSetIndexes stringValueSetIndexes = indexSupplier.as(StringValueSetIndexes.class); - if (stringValueSetIndexes != null) { - - return stringValueSetIndexes.forValue(matchValueEval.asString()); + return stringValueSetIndexes.forValue(matchValueEval.asString()); + } } } + // fall back to predicate based index if it is available final DruidPredicateIndexes predicateIndexes = indexSupplier.as(DruidPredicateIndexes.class); if (predicateIndexes != null) { @@ -408,11 +428,38 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt private Supplier> makeStringPredicateSupplier() { return 
Suppliers.memoize(() -> { - final ExprEval castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING); - if (castForComparison == null) { - return DruidObjectPredicate.alwaysFalseWithNullUnknown(); + // when matching strings to numeric match values, use numeric comparator to implicitly cast the string to number + if (matchValue.type().isNumeric()) { + if (matchValue.type().is(ExprType.LONG)) { + return value -> { + if (value == null) { + return DruidPredicateMatch.UNKNOWN; + } + final Long l = DimensionHandlerUtils.convertObjectToLong(value); + if (l == null) { + return DruidPredicateMatch.FALSE; + } + return DruidPredicateMatch.of(matchValue.asLong() == l); + }; + } else { + return value -> { + if (value == null) { + return DruidPredicateMatch.UNKNOWN; + } + final Double d = DimensionHandlerUtils.convertObjectToDouble(value); + if (d == null) { + return DruidPredicateMatch.FALSE; + } + return DruidPredicateMatch.of(matchValue.asDouble() == d); + }; + } + } else { + final ExprEval castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING); + if (castForComparison == null) { + return DruidObjectPredicate.alwaysFalseWithNullUnknown(); + } + return DruidObjectPredicate.equalTo(castForComparison.asString()); } - return DruidObjectPredicate.equalTo(castForComparison.asString()); }); } @@ -548,6 +595,10 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt @Override public ValueMatcher makeDimensionProcessor(DimensionSelector selector, boolean multiValue) { + // use the predicate matcher when matching numeric values since it casts the strings to numeric types + if (matchValue.type().isNumeric()) { + return predicateMatcherFactory.makeDimensionProcessor(selector, multiValue); + } final ExprEval castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING); if (castForComparison == null) { return 
ValueMatchers.makeAlwaysFalseWithNullUnknownDimensionMatcher(selector, multiValue); diff --git a/processing/src/main/java/org/apache/druid/query/filter/RangeFilter.java b/processing/src/main/java/org/apache/druid/query/filter/RangeFilter.java index 63fc48559ac..527b5912208 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/RangeFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/RangeFilter.java @@ -314,8 +314,7 @@ public class RangeFilter extends AbstractOptimizableDimFilter implements Filter final String upper = hasUpperBound() ? upperEval.asString() : null; return rangeIndexes.forRange(lower, lowerOpen, upper, upperOpen); } - } - if (matchValueType.isNumeric()) { + } else if (matchValueType.isNumeric()) { final NumericRangeIndexes rangeIndexes = indexSupplier.as(NumericRangeIndexes.class); if (rangeIndexes != null) { final Number lower = (Number) lowerEval.value(); diff --git a/processing/src/main/java/org/apache/druid/query/filter/TypedInFilter.java b/processing/src/main/java/org/apache/druid/query/filter/TypedInFilter.java index 63e3fbd4541..1230b522111 100644 --- a/processing/src/main/java/org/apache/druid/query/filter/TypedInFilter.java +++ b/processing/src/main/java/org/apache/druid/query/filter/TypedInFilter.java @@ -36,15 +36,21 @@ import com.google.common.collect.Sets; import com.google.common.collect.TreeRangeSet; import com.google.common.hash.Hasher; import com.google.common.hash.Hashing; +import com.google.common.primitives.Doubles; import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import it.unimi.dsi.fastutil.doubles.DoubleSet; import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongSet; import it.unimi.dsi.fastutil.objects.ObjectArrays; import org.apache.druid.common.config.NullHandling; +import org.apache.druid.common.guava.GuavaUtils; import org.apache.druid.error.InvalidInput; import 
org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.math.expr.Evals; +import org.apache.druid.math.expr.ExprEval; +import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.filter.vector.VectorValueMatcher; import org.apache.druid.query.filter.vector.VectorValueMatcherColumnProcessorFactory; @@ -301,9 +307,11 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte } } - final ValueSetIndexes valueSetIndexes = indexSupplier.as(ValueSetIndexes.class); - if (valueSetIndexes != null) { - return valueSetIndexes.forSortedValues(sortedMatchValues.get(), matchValueType); + if (EqualityFilter.useSimpleEquality(selector.getColumnCapabilities(column), matchValueType)) { + final ValueSetIndexes valueSetIndexes = indexSupplier.as(ValueSetIndexes.class); + if (valueSetIndexes != null) { + return valueSetIndexes.forSortedValues(sortedMatchValues.get(), matchValueType); + } } return Filters.makePredicateIndex( @@ -452,20 +460,20 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte } @Nullable - private static Object coerceValue(@Nullable Object o, ColumnType matchValueType) + private static T coerceValue(@Nullable Object o, ColumnType matchValueType) { if (o == null) { return null; } switch (matchValueType.getType()) { case STRING: - return DimensionHandlerUtils.convertObjectToString(o); + return (T) DimensionHandlerUtils.convertObjectToString(o); case LONG: - return DimensionHandlerUtils.convertObjectToLong(o); + return (T) DimensionHandlerUtils.convertObjectToLong(o); case FLOAT: - return DimensionHandlerUtils.convertObjectToFloat(o); + return (T) DimensionHandlerUtils.convertObjectToFloat(o); case DOUBLE: - return DimensionHandlerUtils.convertObjectToDouble(o); + return (T) DimensionHandlerUtils.convertObjectToDouble(o); default: throw InvalidInput.exception("Unsupported 
matchValueType[%s]", matchValueType); } @@ -540,11 +548,51 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte final int index = Collections.binarySearch(sortedValues, value, comparator); return DruidPredicateMatch.of(index >= 0); }; + } else if (matchValueType.is(ValueType.LONG)) { + final LongSet valueSet = new LongOpenHashSet(sortedValues.size()); + for (Object o : sortedValues) { + final Long l = DimensionHandlerUtils.convertObjectToLong(o); + if (l != null) { + valueSet.add(l.longValue()); + } + } + return value -> { + if (value == null) { + return containsNull ? DruidPredicateMatch.TRUE : DruidPredicateMatch.UNKNOWN; + } + final Long castValue = GuavaUtils.tryParseLong(value); + if (castValue == null) { + return DruidPredicateMatch.FALSE; + } + return DruidPredicateMatch.of(valueSet.contains(castValue)); + }; + } else if (matchValueType.isNumeric()) { + // double or float + final DoubleSet valueSet = new DoubleOpenHashSet(sortedValues.size()); + for (Object o : sortedValues) { + Double d = DimensionHandlerUtils.convertObjectToDouble(o); + if (d != null) { + valueSet.add(d.doubleValue()); + } + } + return value -> { + if (value == null) { + return containsNull ? 
DruidPredicateMatch.TRUE : DruidPredicateMatch.UNKNOWN; + } + + final Double d = Doubles.tryParse(value); + if (d == null) { + return DruidPredicateMatch.FALSE; + } + return DruidPredicateMatch.of(valueSet.contains(d)); + }; } + // convert set to strings + final ExpressionType matchExpressionType = ExpressionType.fromColumnTypeStrict(matchValueType); final Set stringSet = Sets.newHashSetWithExpectedSize(sortedValues.size()); for (Object o : sortedValues) { - stringSet.add(Evals.asString(o)); + stringSet.add(ExprEval.ofType(matchExpressionType, o).castTo(ExpressionType.STRING).asString()); } return value -> { if (value == null) { diff --git a/processing/src/main/java/org/apache/druid/segment/index/IndexedUtf8ValueIndexes.java b/processing/src/main/java/org/apache/druid/segment/index/IndexedUtf8ValueIndexes.java index 65395d148b2..6015088d558 100644 --- a/processing/src/main/java/org/apache/druid/segment/index/IndexedUtf8ValueIndexes.java +++ b/processing/src/main/java/org/apache/druid/segment/index/IndexedUtf8ValueIndexes.java @@ -230,7 +230,7 @@ public final class IndexedUtf8ValueIndexes= 0 ? 
position : -(position + 1), baseSet.size()); diff --git a/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java b/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java index 3d35817531f..fd87969a042 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/EqualityFilterTests.java @@ -108,6 +108,12 @@ public class EqualityFilterTests NotDimFilter.of(new EqualityFilter("dim0", ColumnType.LONG, 1L, null)), ImmutableList.of("0", "2", "3", "4", "5") ); + + assertFilterMatches(new EqualityFilter("dim0", ColumnType.DOUBLE, 1, null), ImmutableList.of("1")); + assertFilterMatches( + NotDimFilter.of(new EqualityFilter("dim0", ColumnType.DOUBLE, 1, null)), + ImmutableList.of("0", "2", "3", "4", "5") + ); } @Test diff --git a/processing/src/test/java/org/apache/druid/segment/filter/InFilterTests.java b/processing/src/test/java/org/apache/druid/segment/filter/InFilterTests.java index fd8c79096c4..6f5c4b72eb1 100644 --- a/processing/src/test/java/org/apache/druid/segment/filter/InFilterTests.java +++ b/processing/src/test/java/org/apache/druid/segment/filter/InFilterTests.java @@ -138,6 +138,29 @@ public class InFilterTests NotDimFilter.of(inFilter("dim0", ColumnType.STRING, Arrays.asList("e", "x"))), ImmutableList.of("a", "b", "c", "d", "f") ); + + if (NullHandling.sqlCompatible()) { + assertTypedFilterMatches( + inFilter("dim1", ColumnType.LONG, Arrays.asList(2L, 10L)), + ImmutableList.of("b", "c") + ); + + assertTypedFilterMatches( + inFilter("dim1", ColumnType.DOUBLE, Arrays.asList(2.0, 10.0)), + ImmutableList.of("b", "c") + ); + } else { + // in default value mode, we actually end up using a classic InDimFilter, it does not match numbers well + assertTypedFilterMatches( + inFilter("dim1", ColumnType.LONG, Arrays.asList(2L, 10L)), + ImmutableList.of("b", "c") + ); + + assertTypedFilterMatches( + inFilter("dim1", 
ColumnType.DOUBLE, Arrays.asList(2.0, 10.0)), + ImmutableList.of() + ); + } } @Test public void testSingleValueStringColumnWithNulls()