From 8ba7f6a48c95074a6632b14fa93447487923b18c Mon Sep 17 00:00:00 2001 From: Jihoon Son Date: Sat, 31 Jul 2021 15:55:49 -0700 Subject: [PATCH] Fix incorrect result of exact topN on an inner join with limit (#11517) --- .../druid/query/topn/BaseTopNAlgorithm.java | 1 + .../druid/query/topn/TopNQueryConfig.java | 4 +- .../apache/druid/segment/StorageAdapter.java | 13 +++ .../join/HashJoinSegmentStorageAdapter.java | 9 +- ...BaseHashJoinSegmentStorageAdapterTest.java | 7 +- .../HashJoinSegmentStorageAdapterTest.java | 97 +++++++++++++++++++ ...yRunnerBasedOnClusteredClientTestBase.java | 7 +- .../apache/druid/server/QueryStackTests.java | 49 ++++++++-- .../sql/calcite/BaseCalciteQueryTest.java | 4 +- .../druid/sql/calcite/CalciteQueryTest.java | 50 ++++++++++ 10 files changed, 227 insertions(+), 14 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java b/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java index b8b04adb771..843d248221e 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java +++ b/processing/src/main/java/org/apache/druid/query/topn/BaseTopNAlgorithm.java @@ -316,6 +316,7 @@ public abstract class BaseTopNAlgorithm interval.contains(storageAdapter.getInterval()))) { endIndex = Math.min(endIndex, startIndex + query.getThreshold()); } diff --git a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java index 9e5dcd31744..2793b270b8a 100644 --- a/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java +++ b/processing/src/main/java/org/apache/druid/query/topn/TopNQueryConfig.java @@ -27,9 +27,11 @@ import javax.validation.constraints.Min; */ public class TopNQueryConfig { + public static final int DEFAULT_MIN_TOPN_THRESHOLD = 1000; + @JsonProperty @Min(1) - private int minTopNThreshold = 1000; + private int minTopNThreshold = 
DEFAULT_MIN_TOPN_THRESHOLD; public int getMinTopNThreshold() { diff --git a/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java index e7905b2902e..2aa4b774506 100644 --- a/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/StorageAdapter.java @@ -76,4 +76,17 @@ public interface StorageAdapter extends CursorFactory, ColumnInspector int getNumRows(); DateTime getMaxIngestedEventTime(); Metadata getMetadata(); + + /** + * Returns true if this storage adapter can filter some rows out. The actual column cardinality can be lower than + * what {@link #getDimensionCardinality} returns if this returns true. Dimension selectors for such storage adapter + * can return non-contiguous dictionary IDs because the dictionary IDs in filtered rows will not be returned. + * Note that the number of rows accessible via this storage adapter will not necessarily decrease because of + * the built-in filters. For inner joins, for example, the number of joined rows can be larger than + * the number of rows in the base adapter even though this method returns true. 
+ */ + default boolean hasBuiltInFilters() + { + return false; + } } diff --git a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java index 86b7ef4aa7d..c056725cad9 100644 --- a/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java +++ b/processing/src/main/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapter.java @@ -226,6 +226,13 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter throw new UnsupportedOperationException("Cannot retrieve metadata from join segment"); } + @Override + public boolean hasBuiltInFilters() + { + return clauses.stream() + .anyMatch(clause -> clause.getJoinType() == JoinType.INNER && !clause.getCondition().isAlwaysTrue()); + } + @Override public boolean canVectorize(@Nullable Filter filter, VirtualColumns virtualColumns, boolean descending) { @@ -343,7 +350,7 @@ public class HashJoinSegmentStorageAdapter implements StorageAdapter return PostJoinCursor.wrap( retVal, VirtualColumns.create(postJoinVirtualColumns), - joinFilterSplit.getJoinTableFilter().isPresent() ? joinFilterSplit.getJoinTableFilter().get() : null + joinFilterSplit.getJoinTableFilter().orElse(null) ); } ).withBaggage(joinablesCloser); diff --git a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java index 26ba16119a0..33199b32864 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/BaseHashJoinSegmentStorageAdapterTest.java @@ -205,10 +205,15 @@ public class BaseHashJoinSegmentStorageAdapterTest * have {@link org.apache.druid.segment.StorageAdapter#makeCursors} called on it. 
*/ protected HashJoinSegmentStorageAdapter makeFactToCountrySegment() + { + return makeFactToCountrySegment(JoinType.LEFT); + } + + protected HashJoinSegmentStorageAdapter makeFactToCountrySegment(JoinType joinType) { return new HashJoinSegmentStorageAdapter( factSegment.asStorageAdapter(), - ImmutableList.of(factToCountryOnIsoCode(JoinType.LEFT)), + ImmutableList.of(factToCountryOnIsoCode(joinType)), null ); } diff --git a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java index 10d048305f8..4ee41a2cef0 100644 --- a/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java +++ b/processing/src/test/java/org/apache/druid/segment/join/HashJoinSegmentStorageAdapterTest.java @@ -2266,4 +2266,101 @@ public class HashJoinSegmentStorageAdapterTest extends BaseHashJoinSegmentStorag Assert.assertEquals(expectedPostJoin, actualPostJoin); } + @Test + public void test_hasBuiltInFiltersForSingleJoinableClauseWithVariousJoinTypes() + { + Assert.assertTrue(makeFactToCountrySegment(JoinType.INNER).hasBuiltInFilters()); + Assert.assertFalse(makeFactToCountrySegment(JoinType.LEFT).hasBuiltInFilters()); + Assert.assertFalse(makeFactToCountrySegment(JoinType.RIGHT).hasBuiltInFilters()); + Assert.assertFalse(makeFactToCountrySegment(JoinType.FULL).hasBuiltInFilters()); + // cross join + Assert.assertFalse( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + ImmutableList.of( + new JoinableClause( + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, + new IndexedTableJoinable(countriesTable), + JoinType.INNER, + JoinConditionAnalysis.forExpression( + "'true'", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, + ExprMacroTable.nil() + ) + ) + ), + null + ).hasBuiltInFilters() + ); + } + + @Test + public void test_hasBuiltInFiltersForEmptyJoinableClause() + { + Assert.assertFalse( + new HashJoinSegmentStorageAdapter( + 
factSegment.asStorageAdapter(), + ImmutableList.of(), + null + ).hasBuiltInFilters() + ); + } + + @Test + public void test_hasBuiltInFiltersForMultipleJoinableClausesWithVariousJoinTypes() + { + Assert.assertTrue( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + ImmutableList.of( + factToRegion(JoinType.INNER), + regionToCountry(JoinType.LEFT) + ), + null + ).hasBuiltInFilters() + ); + + Assert.assertTrue( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + ImmutableList.of( + factToRegion(JoinType.RIGHT), + regionToCountry(JoinType.INNER), + factToCountryOnNumber(JoinType.FULL) + ), + null + ).hasBuiltInFilters() + ); + + Assert.assertFalse( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + ImmutableList.of( + factToRegion(JoinType.LEFT), + regionToCountry(JoinType.LEFT) + ), + null + ).hasBuiltInFilters() + ); + + Assert.assertFalse( + new HashJoinSegmentStorageAdapter( + factSegment.asStorageAdapter(), + ImmutableList.of( + factToRegion(JoinType.LEFT), + new JoinableClause( + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, + new IndexedTableJoinable(countriesTable), + JoinType.INNER, + JoinConditionAnalysis.forExpression( + "'true'", + FACT_TO_COUNTRY_ON_ISO_CODE_PREFIX, + ExprMacroTable.nil() + ) + ) + ), + null + ).hasBuiltInFilters() + ); + } } diff --git a/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java b/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java index 94c6c596090..97457b8688e 100644 --- a/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java +++ b/server/src/test/java/org/apache/druid/query/QueryRunnerBasedOnClusteredClientTestBase.java @@ -46,6 +46,7 @@ import org.apache.druid.query.context.ConcurrentResponseContext; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext.Key; import 
org.apache.druid.query.timeseries.TimeseriesResultValue; +import org.apache.druid.query.topn.TopNQueryConfig; import org.apache.druid.segment.QueryableIndex; import org.apache.druid.segment.generator.GeneratorBasicSchemas; import org.apache.druid.segment.generator.GeneratorSchemaInfo; @@ -108,7 +109,11 @@ public abstract class QueryRunnerBasedOnClusteredClientTestBase protected QueryRunnerBasedOnClusteredClientTestBase() { - conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(CLOSER, USE_PARALLEL_MERGE_POOL_CONFIGURED); + conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate( + CLOSER, + USE_PARALLEL_MERGE_POOL_CONFIGURED, + () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD + ); toolChestWarehouse = new QueryToolChestWarehouse() { diff --git a/server/src/test/java/org/apache/druid/server/QueryStackTests.java b/server/src/test/java/org/apache/druid/server/QueryStackTests.java index 074649f26a6..9e7c234c0c6 100644 --- a/server/src/test/java/org/apache/druid/server/QueryStackTests.java +++ b/server/src/test/java/org/apache/druid/server/QueryStackTests.java @@ -80,6 +80,7 @@ import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; /** * Utilities for creating query-stack objects for tests. 
@@ -228,20 +229,30 @@ public class QueryStackTests */ public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate(final Closer closer) { - return createQueryRunnerFactoryConglomerate(closer, true); + return createQueryRunnerFactoryConglomerate(closer, true, () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD); } public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate( final Closer closer, - final boolean useParallelMergePoolConfigured - + final Supplier<Integer> minTopNThresholdSupplier ) { - return createQueryRunnerFactoryConglomerate(closer, - getProcessingConfig( - useParallelMergePoolConfigured, - DruidProcessingConfig.DEFAULT_NUM_MERGE_BUFFERS - ) + return createQueryRunnerFactoryConglomerate(closer, true, minTopNThresholdSupplier); + } + + public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate( + final Closer closer, + final boolean useParallelMergePoolConfigured, + final Supplier<Integer> minTopNThresholdSupplier + ) + { + return createQueryRunnerFactoryConglomerate( + closer, + getProcessingConfig( + useParallelMergePoolConfigured, + DruidProcessingConfig.DEFAULT_NUM_MERGE_BUFFERS + ), + minTopNThresholdSupplier ); } @@ -249,6 +260,19 @@ public class QueryStackTests final Closer closer, final DruidProcessingConfig processingConfig ) + { + return createQueryRunnerFactoryConglomerate( + closer, + processingConfig, + () -> TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD + ); + } + + public static QueryRunnerFactoryConglomerate createQueryRunnerFactoryConglomerate( + final Closer closer, + final DruidProcessingConfig processingConfig, + final Supplier<Integer> minTopNThresholdSupplier + ) { final CloseableStupidPool<ByteBuffer> stupidPool = new CloseableStupidPool<>( "TopNQueryRunnerFactory-bufferPool", @@ -308,7 +332,14 @@ public class QueryStackTests TopNQuery.class, new TopNQueryRunnerFactory( stupidPool, - new TopNQueryQueryToolChest(new TopNQueryConfig()), + new TopNQueryQueryToolChest(new TopNQueryConfig() + { + @Override + 
public int getMinTopNThreshold() + { + return minTopNThresholdSupplier.get(); + } + }), QueryRunnerTestHelper.NOOP_QUERYWATCHER ) ) diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java index c8988cfb4fa..b4c36829197 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/BaseCalciteQueryTest.java @@ -62,6 +62,7 @@ import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.QuerySegmentSpec; import org.apache.druid.query.timeseries.TimeseriesQuery; +import org.apache.druid.query.topn.TopNQueryConfig; import org.apache.druid.segment.column.ColumnHolder; import org.apache.druid.segment.column.ValueType; import org.apache.druid.segment.join.JoinType; @@ -248,6 +249,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase public static QueryRunnerFactoryConglomerate conglomerate; public static Closer resourceCloser; + public static int minTopNThreshold = TopNQueryConfig.DEFAULT_MIN_TOPN_THRESHOLD; @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -444,7 +446,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase public static void setUpClass() { resourceCloser = Closer.create(); - conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser); + conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser, () -> minTopNThreshold); } @AfterClass diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index 9404b037f00..d05b1503394 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -353,6 +353,56 @@ public class 
CalciteQueryTest extends BaseCalciteQueryTest ); } + @Test + public void testExactTopNOnInnerJoinWithLimit() throws Exception + { + // Adjust topN threshold, so that the topN engine keeps only 1 slot for aggregates, which should be enough + // to compute the query with limit 1. + minTopNThreshold = 1; + Map<String, Object> context = new HashMap<>(QUERY_CONTEXT_DEFAULT); + context.put(PlannerConfig.CTX_KEY_USE_APPROXIMATE_TOPN, false); + testQuery( + "select f1.\"dim4\", sum(\"m1\") from numfoo f1 inner join (\n" + + " select \"dim4\" from numfoo where dim4 <> 'a' group by 1\n" + + ") f2 on f1.\"dim4\" = f2.\"dim4\" group by 1 limit 1", + context, // turn on exact topN + ImmutableList.of( + new TopNQueryBuilder() + .intervals(querySegmentSpec(Filtration.eternity())) + .granularity(Granularities.ALL) + .dimension(new DefaultDimensionSpec("dim4", "_d0")) + .aggregators(new DoubleSumAggregatorFactory("a0", "m1")) + .metric(new DimensionTopNMetricSpec(null, StringComparators.LEXICOGRAPHIC)) + .threshold(1) + .dataSource( + JoinDataSource.create( + new TableDataSource("numfoo"), + new QueryDataSource( + GroupByQuery.builder() + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimFilter(new NotDimFilter(new SelectorDimFilter("dim4", "a", null))) + .setDataSource(new TableDataSource("numfoo")) + .setDimensions(new DefaultDimensionSpec("dim4", "_d0")) + .setContext(context) + .build() + ), + "j0.", + "(\"dim4\" == \"j0._d0\")", + JoinType.INNER, + null, + ExprMacroTable.nil() + ) + ) + .context(context) + .build() + ), + ImmutableList.of( + new Object[]{"b", 15.0} + ) + ); + } + @Test public void testJoinOuterGroupByAndSubqueryHasLimit() throws Exception {