Moving object contains to Bound for string/object matchers (#16241)

This commit is contained in:
Pranav 2024-05-23 07:56:04 -07:00 committed by GitHub
parent 12f79acc7e
commit 204a25d3e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 115 additions and 27 deletions

View File

@ -70,10 +70,7 @@ public class ImmutableFloatNode implements ImmutableNode<float[]>
this.numChildren = (short) (header & 0x7FFF); this.numChildren = (short) (header & 0x7FFF);
final int sizePosition = initialOffset + offsetFromInitial + HEADER_NUM_BYTES + 2 * numDims * Float.BYTES; final int sizePosition = initialOffset + offsetFromInitial + HEADER_NUM_BYTES + 2 * numDims * Float.BYTES;
int bitmapSize = data.getInt(sizePosition); int bitmapSize = data.getInt(sizePosition);
this.childrenOffset = initialOffset this.childrenOffset = sizePosition
+ offsetFromInitial
+ HEADER_NUM_BYTES
+ 2 * numDims * Float.BYTES
+ Integer.BYTES + Integer.BYTES
+ bitmapSize; + bitmapSize;
@ -98,10 +95,7 @@ public class ImmutableFloatNode implements ImmutableNode<float[]>
this.isLeaf = leaf; this.isLeaf = leaf;
final int sizePosition = initialOffset + offsetFromInitial + HEADER_NUM_BYTES + 2 * numDims * Float.BYTES; final int sizePosition = initialOffset + offsetFromInitial + HEADER_NUM_BYTES + 2 * numDims * Float.BYTES;
int bitmapSize = data.getInt(sizePosition); int bitmapSize = data.getInt(sizePosition);
this.childrenOffset = initialOffset this.childrenOffset = sizePosition
+ offsetFromInitial
+ HEADER_NUM_BYTES
+ 2 * numDims * Float.BYTES
+ Integer.BYTES + Integer.BYTES
+ bitmapSize; + bitmapSize;

View File

@ -24,6 +24,8 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.druid.annotations.SubclassesMustOverrideEqualsAndHashCode; import org.apache.druid.annotations.SubclassesMustOverrideEqualsAndHashCode;
import org.apache.druid.collections.spatial.ImmutableNode; import org.apache.druid.collections.spatial.ImmutableNode;
import javax.annotation.Nullable;
/** /**
*/ */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@ -43,6 +45,14 @@ public interface Bound<TCoordinateArray, TPoint extends ImmutableNode<TCoordinat
boolean contains(TCoordinateArray coords); boolean contains(TCoordinateArray coords);
/***
* containsObj is mainly used to create object matechers on top custom/extensible spatial column,
* it receives it as object and corresponding implementations need to logic to unpack the objects and invoke contains
* @param input Takes an object spatial column as input
* @return boolean value if it falls within given bound
*/
boolean containsObj(@Nullable Object input);
Iterable<TPoint> filter(Iterable<TPoint> points); Iterable<TPoint> filter(Iterable<TPoint> points);
byte[] getCacheKey(); byte[] getCacheKey();

View File

@ -26,12 +26,15 @@ import com.google.common.base.Predicate;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import org.apache.druid.collections.spatial.ImmutableFloatPoint; import org.apache.druid.collections.spatial.ImmutableFloatPoint;
import org.apache.druid.collections.spatial.ImmutableNode; import org.apache.druid.collections.spatial.ImmutableNode;
import org.apache.druid.segment.incremental.SpatialDimensionRowTransformer;
import javax.annotation.Nullable;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.Arrays; import java.util.Arrays;
import java.util.Objects; import java.util.Objects;
/** /**
*
*/ */
public class RectangularBound implements Bound<float[], ImmutableFloatPoint> public class RectangularBound implements Bound<float[], ImmutableFloatPoint>
{ {
@ -118,6 +121,19 @@ public class RectangularBound implements Bound<float[], ImmutableFloatPoint>
return true; return true;
} }
@Override
public boolean containsObj(@Nullable Object input)
{
if (input instanceof String) {
final float[] coordinate = SpatialDimensionRowTransformer.decode((String) input);
if (coordinate == null) {
return false;
}
return contains(coordinate);
}
return false;
}
@Override @Override
public Iterable<ImmutableFloatPoint> filter(Iterable<ImmutableFloatPoint> points) public Iterable<ImmutableFloatPoint> filter(Iterable<ImmutableFloatPoint> points)
{ {

View File

@ -43,7 +43,6 @@ import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.ColumnIndexCapabilities; import org.apache.druid.segment.column.ColumnIndexCapabilities;
import org.apache.druid.segment.column.ColumnIndexSupplier; import org.apache.druid.segment.column.ColumnIndexSupplier;
import org.apache.druid.segment.column.SimpleColumnIndexCapabilities; import org.apache.druid.segment.column.SimpleColumnIndexCapabilities;
import org.apache.druid.segment.incremental.SpatialDimensionRowTransformer;
import org.apache.druid.segment.index.AllUnknownBitmapColumnIndex; import org.apache.druid.segment.index.AllUnknownBitmapColumnIndex;
import org.apache.druid.segment.index.BitmapColumnIndex; import org.apache.druid.segment.index.BitmapColumnIndex;
import org.apache.druid.segment.index.semantic.SpatialIndex; import org.apache.druid.segment.index.semantic.SpatialIndex;
@ -174,8 +173,18 @@ public class SpatialFilter implements Filter
if (input == null) { if (input == null) {
return DruidPredicateMatch.UNKNOWN; return DruidPredicateMatch.UNKNOWN;
} }
final float[] coordinate = SpatialDimensionRowTransformer.decode(input); return DruidPredicateMatch.of(bound.containsObj(input));
return DruidPredicateMatch.of(bound.contains(coordinate)); };
}
@Override
public DruidObjectPredicate<Object> makeObjectPredicate()
{
return input -> {
if (input == null) {
return DruidPredicateMatch.UNKNOWN;
}
return DruidPredicateMatch.of(bound.containsObj(input));
}; };
} }

View File

@ -40,12 +40,18 @@ import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory; import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory; import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.TestObjectColumnSelector;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.filter.FilterTuning;
import org.apache.druid.query.filter.SpatialDimFilter; import org.apache.druid.query.filter.SpatialDimFilter;
import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryEngine; import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest; import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory; import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import org.apache.druid.query.timeseries.TimeseriesResultValue; import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.IncrementalIndexSegment; import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.IndexIO; import org.apache.druid.segment.IndexIO;
import org.apache.druid.segment.IndexMerger; import org.apache.druid.segment.IndexMerger;
@ -54,6 +60,9 @@ import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.QueryableIndexSegment; import org.apache.druid.segment.QueryableIndexSegment;
import org.apache.druid.segment.Segment; import org.apache.druid.segment.Segment;
import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.StringEncodingStrategy; import org.apache.druid.segment.column.StringEncodingStrategy;
import org.apache.druid.segment.data.FrontCodedIndexed; import org.apache.druid.segment.data.FrontCodedIndexed;
import org.apache.druid.segment.incremental.IncrementalIndex; import org.apache.druid.segment.incremental.IncrementalIndex;
@ -62,28 +71,31 @@ import org.apache.druid.segment.incremental.OnheapIncrementalIndex;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory; import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.testing.InitializedNullHandlingTest; import org.apache.druid.testing.InitializedNullHandlingTest;
import org.joda.time.Interval; import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
import javax.annotation.Nullable;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Date;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.ThreadLocalRandom;
/** /**
*
*/ */
@RunWith(Parameterized.class) @RunWith(Parameterized.class)
public class SpatialFilterTest extends InitializedNullHandlingTest public class SpatialFilterTest extends InitializedNullHandlingTest
{ {
public static final int NUM_POINTS = 5000;
private static IndexMerger INDEX_MERGER = TestHelper.getTestIndexMergerV9(OffHeapMemorySegmentWriteOutMediumFactory.instance()); private static IndexMerger INDEX_MERGER = TestHelper.getTestIndexMergerV9(OffHeapMemorySegmentWriteOutMediumFactory.instance());
private static IndexIO INDEX_IO = TestHelper.getTestIndexIO(); private static IndexIO INDEX_IO = TestHelper.getTestIndexIO();
public static final int NUM_POINTS = 5000;
private static Interval DATA_INTERVAL = Intervals.of("2013-01-01/2013-01-07"); private static Interval DATA_INTERVAL = Intervals.of("2013-01-01/2013-01-07");
private static AggregatorFactory[] METRIC_AGGS = new AggregatorFactory[]{ private static AggregatorFactory[] METRIC_AGGS = new AggregatorFactory[]{
@ -92,6 +104,12 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
}; };
private static List<String> DIMS = Lists.newArrayList("dim", "lat", "long", "lat2", "long2"); private static List<String> DIMS = Lists.newArrayList("dim", "lat", "long", "lat2", "long2");
private final Segment segment;
public SpatialFilterTest(Segment segment)
{
this.segment = segment;
}
@Parameterized.Parameters @Parameterized.Parameters
public static Collection<?> constructorFeeder() throws IOException public static Collection<?> constructorFeeder() throws IOException
@ -517,7 +535,11 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
QueryableIndex mergedRealtime = INDEX_IO.loadIndex( QueryableIndex mergedRealtime = INDEX_IO.loadIndex(
INDEX_MERGER.mergeQueryableIndex( INDEX_MERGER.mergeQueryableIndex(
Arrays.asList(INDEX_IO.loadIndex(firstFile), INDEX_IO.loadIndex(secondFile), INDEX_IO.loadIndex(thirdFile)), Arrays.asList(
INDEX_IO.loadIndex(firstFile),
INDEX_IO.loadIndex(secondFile),
INDEX_IO.loadIndex(thirdFile)
),
true, true,
METRIC_AGGS, METRIC_AGGS,
mergedFile, mergedFile,
@ -534,13 +556,6 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
} }
} }
private final Segment segment;
public SpatialFilterTest(Segment segment)
{
this.segment = segment;
}
@Test @Test
public void testSpatialQuery() public void testSpatialQuery()
{ {
@ -567,9 +582,9 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
DateTimes.of("2013-01-01T00:00:00.000Z"), DateTimes.of("2013-01-01T00:00:00.000Z"),
new TimeseriesResultValue( new TimeseriesResultValue(
ImmutableMap.<String, Object>builder() ImmutableMap.<String, Object>builder()
.put("rows", 3L) .put("rows", 3L)
.put("val", 59L) .put("val", 59L)
.build() .build()
) )
) )
); );
@ -619,9 +634,9 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
DateTimes.of("2013-01-01T00:00:00.000Z"), DateTimes.of("2013-01-01T00:00:00.000Z"),
new TimeseriesResultValue( new TimeseriesResultValue(
ImmutableMap.<String, Object>builder() ImmutableMap.<String, Object>builder()
.put("rows", 1L) .put("rows", 1L)
.put("val", 13L) .put("val", 13L)
.build() .build()
) )
) )
); );
@ -742,4 +757,48 @@ public class SpatialFilterTest extends InitializedNullHandlingTest
{ {
EqualsVerifier.forClass(SpatialFilter.BoundDruidPredicateFactory.class).usingGetClass().verify(); EqualsVerifier.forClass(SpatialFilter.BoundDruidPredicateFactory.class).usingGetClass().verify();
} }
@Test
public void testSpatialFilter()
{
SpatialFilter spatialFilter = new SpatialFilter(
"test",
new RadiusBound(new float[]{0, 0}, 0f, 0),
new FilterTuning(false, 1, 1)
);
// String complex
Assert.assertTrue(spatialFilter.makeMatcher(new TestSpatialSelectorFactory("0,0")).matches(true));
// Unknown complex, invokes object predicate
Assert.assertFalse(spatialFilter.makeMatcher(new TestSpatialSelectorFactory(new Date())).matches(true));
Assert.assertFalse(spatialFilter.makeMatcher(new TestSpatialSelectorFactory(new Object())).matches(true));
}
static class TestSpatialSelectorFactory implements ColumnSelectorFactory
{
Object object;
public TestSpatialSelectorFactory(Object value)
{
object = value;
}
@Override
public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
{
return null;
}
@Override
public ColumnValueSelector makeColumnValueSelector(String columnName)
{
return new TestObjectColumnSelector(new Object[]{object});
}
@Nullable
@Override
public ColumnCapabilities getColumnCapabilities(String column)
{
return ColumnCapabilitiesImpl.createDefault().setType(ColumnType.UNKNOWN_COMPLEX);
}
}
} }