modify equality and typed in filter behavior for numeric match values on string columns (#16593)

* fix equality and typed in filter behavior for numeric match values on string columns
changes:
* EqualityFilter and TypedInFilter numeric match values against string columns now cast the strings to numeric values, instead of converting the numeric values directly to strings for pure string equality. This is consistent with the casts which are eaten in the SQL layer, as well as with classic Druid behavior
* added tests to cover numeric equality matching. Double match values in particular previously failed to match the string values, since `1.0` would become `'1.0'`, which does not match `'1'`.
This commit is contained in:
Clint Wylie 2024-07-08 10:58:05 -07:00 committed by GitHub
parent 7c6f2b1e20
commit 09e0eefdc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 156 additions and 29 deletions

View File

@ -31,6 +31,7 @@ import com.google.common.collect.TreeRangeSet;
import org.apache.druid.error.InvalidInput; import org.apache.druid.error.InvalidInput;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.math.expr.ExprEval; import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionType; import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.filter.vector.VectorValueMatcher; import org.apache.druid.query.filter.vector.VectorValueMatcher;
@ -43,12 +44,14 @@ import org.apache.druid.segment.ColumnInspector;
import org.apache.druid.segment.ColumnProcessorFactory; import org.apache.druid.segment.ColumnProcessorFactory;
import org.apache.druid.segment.ColumnProcessors; import org.apache.druid.segment.ColumnProcessors;
import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.DimensionSelector; import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities; import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnIndexSupplier; import org.apache.druid.segment.column.ColumnIndexSupplier;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.TypeSignature; import org.apache.druid.segment.column.TypeSignature;
import org.apache.druid.segment.column.TypeStrategy; import org.apache.druid.segment.column.TypeStrategy;
import org.apache.druid.segment.column.Types;
import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.column.ValueType;
import org.apache.druid.segment.filter.Filters; import org.apache.druid.segment.filter.Filters;
import org.apache.druid.segment.filter.PredicateValueMatcherFactory; import org.apache.druid.segment.filter.PredicateValueMatcherFactory;
@ -244,8 +247,9 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
public VectorValueMatcher makeVectorMatcher(VectorColumnSelectorFactory factory) public VectorValueMatcher makeVectorMatcher(VectorColumnSelectorFactory factory)
{ {
final ColumnCapabilities capabilities = factory.getColumnCapabilities(column); final ColumnCapabilities capabilities = factory.getColumnCapabilities(column);
final boolean primitiveMatch = matchValueType.isPrimitive() && (capabilities == null || capabilities.isPrimitive());
if (matchValueType.isPrimitive() && (capabilities == null || capabilities.isPrimitive())) { if (primitiveMatch && useSimpleEquality(capabilities, matchValueType)) {
// if possible, use simplified value matcher instead of predicate
return ColumnProcessors.makeVectorProcessor( return ColumnProcessors.makeVectorProcessor(
column, column,
VectorValueMatcherColumnProcessorFactory.instance(), VectorValueMatcherColumnProcessorFactory.instance(),
@ -298,6 +302,20 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
); );
} }
/**
 * Decides whether equality can be evaluated by comparing the match value directly against the column's values.
 * The answer is {@code false} only when the column is a string column and the match value type is numeric; in that
 * case the string values are instead cast to the numeric match value type before being compared. Every other
 * combination of column type and match value type yields {@code true}.
 *
 * @param columnType     type of the column being filtered
 * @param matchValueType type of the value the filter is matching against
 * @return {@code true} if direct equality should be used, {@code false} for a numeric match value on a string column
 */
public static boolean useSimpleEquality(TypeSignature<ValueType> columnType, ColumnType matchValueType)
{
  // the single case that requires casting the column's strings to numbers rather than direct comparison
  final boolean numericMatchOnStringColumn = Types.is(columnType, ValueType.STRING) && matchValueType.isNumeric();
  return !numericMatchOnStringColumn;
}
@Nullable
public static BitmapColumnIndex getEqualityIndex( public static BitmapColumnIndex getEqualityIndex(
String column, String column,
ExprEval<?> matchValueEval, ExprEval<?> matchValueEval,
@ -311,13 +329,13 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
return new AllUnknownBitmapColumnIndex(selector); return new AllUnknownBitmapColumnIndex(selector);
} }
if (useSimpleEquality(selector.getColumnCapabilities(column), matchValueType)) {
final ValueIndexes valueIndexes = indexSupplier.as(ValueIndexes.class); final ValueIndexes valueIndexes = indexSupplier.as(ValueIndexes.class);
if (valueIndexes != null) { if (valueIndexes != null) {
// matchValueEval.value() cannot be null here due to check in the constructor // matchValueEval.value() cannot be null here due to check in the constructor
//noinspection DataFlowIssue //noinspection DataFlowIssue
return valueIndexes.forValue(matchValueEval.value(), matchValueType); return valueIndexes.forValue(matchValueEval.value(), matchValueType);
} }
if (matchValueType.isPrimitive()) { if (matchValueType.isPrimitive()) {
final StringValueSetIndexes stringValueSetIndexes = indexSupplier.as(StringValueSetIndexes.class); final StringValueSetIndexes stringValueSetIndexes = indexSupplier.as(StringValueSetIndexes.class);
if (stringValueSetIndexes != null) { if (stringValueSetIndexes != null) {
@ -325,6 +343,8 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
return stringValueSetIndexes.forValue(matchValueEval.asString()); return stringValueSetIndexes.forValue(matchValueEval.asString());
} }
} }
}
// fall back to predicate based index if it is available // fall back to predicate based index if it is available
final DruidPredicateIndexes predicateIndexes = indexSupplier.as(DruidPredicateIndexes.class); final DruidPredicateIndexes predicateIndexes = indexSupplier.as(DruidPredicateIndexes.class);
if (predicateIndexes != null) { if (predicateIndexes != null) {
@ -408,11 +428,38 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
private Supplier<DruidObjectPredicate<String>> makeStringPredicateSupplier() private Supplier<DruidObjectPredicate<String>> makeStringPredicateSupplier()
{ {
return Suppliers.memoize(() -> { return Suppliers.memoize(() -> {
// when matching strings to numeric match values, use numeric comparator to implicitly cast the string to number
if (matchValue.type().isNumeric()) {
if (matchValue.type().is(ExprType.LONG)) {
return value -> {
if (value == null) {
return DruidPredicateMatch.UNKNOWN;
}
final Long l = DimensionHandlerUtils.convertObjectToLong(value);
if (l == null) {
return DruidPredicateMatch.FALSE;
}
return DruidPredicateMatch.of(matchValue.asLong() == l);
};
} else {
return value -> {
if (value == null) {
return DruidPredicateMatch.UNKNOWN;
}
final Double d = DimensionHandlerUtils.convertObjectToDouble(value);
if (d == null) {
return DruidPredicateMatch.FALSE;
}
return DruidPredicateMatch.of(matchValue.asDouble() == d);
};
}
} else {
final ExprEval<?> castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING); final ExprEval<?> castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING);
if (castForComparison == null) { if (castForComparison == null) {
return DruidObjectPredicate.alwaysFalseWithNullUnknown(); return DruidObjectPredicate.alwaysFalseWithNullUnknown();
} }
return DruidObjectPredicate.equalTo(castForComparison.asString()); return DruidObjectPredicate.equalTo(castForComparison.asString());
}
}); });
} }
@ -548,6 +595,10 @@ public class EqualityFilter extends AbstractOptimizableDimFilter implements Filt
@Override @Override
public ValueMatcher makeDimensionProcessor(DimensionSelector selector, boolean multiValue) public ValueMatcher makeDimensionProcessor(DimensionSelector selector, boolean multiValue)
{ {
// use the predicate matcher when matching numeric values since it casts the strings to numeric types
if (matchValue.type().isNumeric()) {
return predicateMatcherFactory.makeDimensionProcessor(selector, multiValue);
}
final ExprEval<?> castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING); final ExprEval<?> castForComparison = ExprEval.castForEqualityComparison(matchValue, ExpressionType.STRING);
if (castForComparison == null) { if (castForComparison == null) {
return ValueMatchers.makeAlwaysFalseWithNullUnknownDimensionMatcher(selector, multiValue); return ValueMatchers.makeAlwaysFalseWithNullUnknownDimensionMatcher(selector, multiValue);

View File

@ -314,8 +314,7 @@ public class RangeFilter extends AbstractOptimizableDimFilter implements Filter
final String upper = hasUpperBound() ? upperEval.asString() : null; final String upper = hasUpperBound() ? upperEval.asString() : null;
return rangeIndexes.forRange(lower, lowerOpen, upper, upperOpen); return rangeIndexes.forRange(lower, lowerOpen, upper, upperOpen);
} }
} } else if (matchValueType.isNumeric()) {
if (matchValueType.isNumeric()) {
final NumericRangeIndexes rangeIndexes = indexSupplier.as(NumericRangeIndexes.class); final NumericRangeIndexes rangeIndexes = indexSupplier.as(NumericRangeIndexes.class);
if (rangeIndexes != null) { if (rangeIndexes != null) {
final Number lower = (Number) lowerEval.value(); final Number lower = (Number) lowerEval.value();

View File

@ -36,15 +36,21 @@ import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeSet; import com.google.common.collect.TreeRangeSet;
import com.google.common.hash.Hasher; import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing; import com.google.common.hash.Hashing;
import com.google.common.primitives.Doubles;
import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
import it.unimi.dsi.fastutil.doubles.DoubleSet;
import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import it.unimi.dsi.fastutil.objects.ObjectArrays; import it.unimi.dsi.fastutil.objects.ObjectArrays;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.common.guava.GuavaUtils;
import org.apache.druid.error.InvalidInput; import org.apache.druid.error.InvalidInput;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.Evals; import org.apache.druid.math.expr.Evals;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.query.cache.CacheKeyBuilder; import org.apache.druid.query.cache.CacheKeyBuilder;
import org.apache.druid.query.filter.vector.VectorValueMatcher; import org.apache.druid.query.filter.vector.VectorValueMatcher;
import org.apache.druid.query.filter.vector.VectorValueMatcherColumnProcessorFactory; import org.apache.druid.query.filter.vector.VectorValueMatcherColumnProcessorFactory;
@ -301,10 +307,12 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte
} }
} }
if (EqualityFilter.useSimpleEquality(selector.getColumnCapabilities(column), matchValueType)) {
final ValueSetIndexes valueSetIndexes = indexSupplier.as(ValueSetIndexes.class); final ValueSetIndexes valueSetIndexes = indexSupplier.as(ValueSetIndexes.class);
if (valueSetIndexes != null) { if (valueSetIndexes != null) {
return valueSetIndexes.forSortedValues(sortedMatchValues.get(), matchValueType); return valueSetIndexes.forSortedValues(sortedMatchValues.get(), matchValueType);
} }
}
return Filters.makePredicateIndex( return Filters.makePredicateIndex(
column, column,
@ -452,20 +460,20 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte
} }
@Nullable @Nullable
private static Object coerceValue(@Nullable Object o, ColumnType matchValueType) private static <T> T coerceValue(@Nullable Object o, ColumnType matchValueType)
{ {
if (o == null) { if (o == null) {
return null; return null;
} }
switch (matchValueType.getType()) { switch (matchValueType.getType()) {
case STRING: case STRING:
return DimensionHandlerUtils.convertObjectToString(o); return (T) DimensionHandlerUtils.convertObjectToString(o);
case LONG: case LONG:
return DimensionHandlerUtils.convertObjectToLong(o); return (T) DimensionHandlerUtils.convertObjectToLong(o);
case FLOAT: case FLOAT:
return DimensionHandlerUtils.convertObjectToFloat(o); return (T) DimensionHandlerUtils.convertObjectToFloat(o);
case DOUBLE: case DOUBLE:
return DimensionHandlerUtils.convertObjectToDouble(o); return (T) DimensionHandlerUtils.convertObjectToDouble(o);
default: default:
throw InvalidInput.exception("Unsupported matchValueType[%s]", matchValueType); throw InvalidInput.exception("Unsupported matchValueType[%s]", matchValueType);
} }
@ -540,11 +548,51 @@ public class TypedInFilter extends AbstractOptimizableDimFilter implements Filte
final int index = Collections.binarySearch(sortedValues, value, comparator); final int index = Collections.binarySearch(sortedValues, value, comparator);
return DruidPredicateMatch.of(index >= 0); return DruidPredicateMatch.of(index >= 0);
}; };
} else if (matchValueType.is(ValueType.LONG)) {
final LongSet valueSet = new LongOpenHashSet(sortedValues.size());
for (Object o : sortedValues) {
final Long l = DimensionHandlerUtils.convertObjectToLong(o);
if (l != null) {
valueSet.add(l.longValue());
} }
}
return value -> {
if (value == null) {
return containsNull ? DruidPredicateMatch.TRUE : DruidPredicateMatch.UNKNOWN;
}
final Long castValue = GuavaUtils.tryParseLong(value);
if (castValue == null) {
return DruidPredicateMatch.FALSE;
}
return DruidPredicateMatch.of(valueSet.contains(castValue));
};
} else if (matchValueType.isNumeric()) {
// double or float
final DoubleSet valueSet = new DoubleOpenHashSet(sortedValues.size());
for (Object o : sortedValues) {
Double d = DimensionHandlerUtils.convertObjectToDouble(o);
if (d != null) {
valueSet.add(d.doubleValue());
}
}
return value -> {
if (value == null) {
return containsNull ? DruidPredicateMatch.TRUE : DruidPredicateMatch.UNKNOWN;
}
final Double d = Doubles.tryParse(value);
if (d == null) {
return DruidPredicateMatch.FALSE;
}
return DruidPredicateMatch.of(valueSet.contains(d));
};
}
// convert set to strings // convert set to strings
final ExpressionType matchExpressionType = ExpressionType.fromColumnTypeStrict(matchValueType);
final Set<String> stringSet = Sets.newHashSetWithExpectedSize(sortedValues.size()); final Set<String> stringSet = Sets.newHashSetWithExpectedSize(sortedValues.size());
for (Object o : sortedValues) { for (Object o : sortedValues) {
stringSet.add(Evals.asString(o)); stringSet.add(ExprEval.ofType(matchExpressionType, o).castTo(ExpressionType.STRING).asString());
} }
return value -> { return value -> {
if (value == null) { if (value == null) {

View File

@ -230,7 +230,7 @@ public final class IndexedUtf8ValueIndexes<TDictionary extends Indexed<ByteBuffe
final Object minValueInColumn = dictionary.get(0); final Object minValueInColumn = dictionary.get(0);
final int position = Collections.binarySearch( final int position = Collections.binarySearch(
sortedValues, sortedValues,
StringUtils.fromUtf8((ByteBuffer) minValueInColumn), StringUtils.fromUtf8Nullable((ByteBuffer) minValueInColumn),
matchValueType.getNullableStrategy() matchValueType.getNullableStrategy()
); );
tailSet = baseSet.subList(position >= 0 ? position : -(position + 1), baseSet.size()); tailSet = baseSet.subList(position >= 0 ? position : -(position + 1), baseSet.size());

View File

@ -108,6 +108,12 @@ public class EqualityFilterTests
NotDimFilter.of(new EqualityFilter("dim0", ColumnType.LONG, 1L, null)), NotDimFilter.of(new EqualityFilter("dim0", ColumnType.LONG, 1L, null)),
ImmutableList.of("0", "2", "3", "4", "5") ImmutableList.of("0", "2", "3", "4", "5")
); );
assertFilterMatches(new EqualityFilter("dim0", ColumnType.DOUBLE, 1, null), ImmutableList.of("1"));
assertFilterMatches(
NotDimFilter.of(new EqualityFilter("dim0", ColumnType.DOUBLE, 1, null)),
ImmutableList.of("0", "2", "3", "4", "5")
);
} }
@Test @Test

View File

@ -138,6 +138,29 @@ public class InFilterTests
NotDimFilter.of(inFilter("dim0", ColumnType.STRING, Arrays.asList("e", "x"))), NotDimFilter.of(inFilter("dim0", ColumnType.STRING, Arrays.asList("e", "x"))),
ImmutableList.of("a", "b", "c", "d", "f") ImmutableList.of("a", "b", "c", "d", "f")
); );
if (NullHandling.sqlCompatible()) {
assertTypedFilterMatches(
inFilter("dim1", ColumnType.LONG, Arrays.asList(2L, 10L)),
ImmutableList.of("b", "c")
);
assertTypedFilterMatches(
inFilter("dim1", ColumnType.DOUBLE, Arrays.asList(2.0, 10.0)),
ImmutableList.of("b", "c")
);
} else {
// in default value mode, we actually end up using a classic InDimFilter, which does not match numbers well
assertTypedFilterMatches(
inFilter("dim1", ColumnType.LONG, Arrays.asList(2L, 10L)),
ImmutableList.of("b", "c")
);
assertTypedFilterMatches(
inFilter("dim1", ColumnType.DOUBLE, Arrays.asList(2.0, 10.0)),
ImmutableList.of()
);
}
} }
@Test @Test
public void testSingleValueStringColumnWithNulls() public void testSingleValueStringColumnWithNulls()