diff --git a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java index f54fd51c68e..da50c04c8fb 100644 --- a/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java +++ b/sql/src/main/java/io/druid/sql/calcite/rel/DruidSemiJoin.java @@ -288,6 +288,8 @@ public class DruidSemiJoin extends DruidRel new ArrayList<>(), new Accumulator, Object[]>() { + int numRows; + @Override public List accumulate(final List theConditions, final Object[] row) { @@ -301,14 +303,14 @@ public class DruidSemiJoin extends DruidRel } final String stringValue = DimensionHandlerUtils.convertObjectToString(value); values.add(stringValue); - if (values.size() > maxSemiJoinRowsInMemory) { + } + + if (valuess.add(values)) { + if (++numRows > maxSemiJoinRowsInMemory) { throw new ResourceLimitExceededException( StringUtils.format("maxSemiJoinRowsInMemory[%,d] exceeded", maxSemiJoinRowsInMemory) ); } - } - - if (valuess.add(values)) { final List subConditions = new ArrayList<>(); for (int i = 0; i < values.size(); i++) { diff --git a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java index d8b01567ba6..d8cdd14603a 100644 --- a/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/io/druid/sql/calcite/CalciteQueryTest.java @@ -39,6 +39,7 @@ import io.druid.query.Query; import io.druid.query.QueryContexts; import io.druid.query.QueryDataSource; import io.druid.query.QueryRunnerFactoryConglomerate; +import io.druid.query.ResourceLimitExceededException; import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.aggregation.DoubleMaxAggregatorFactory; @@ -187,6 +188,13 @@ public class CalciteQueryTest extends CalciteTestBase return DateTimes.inferTzfromString("America/Los_Angeles"); } }; + private static final PlannerConfig PLANNER_CONFIG_SEMI_JOIN_ROWS_LIMIT = new PlannerConfig() { + @Override + public int getMaxSemiJoinRowsInMemory() + { + return 2; + } + }; private static final String LOS_ANGELES = "America/Los_Angeles"; @@ -4696,6 +4704,24 @@ public class CalciteQueryTest extends CalciteTestBase ); } + @Test + public void testMaxSemiJoinRowsInMemory() throws Exception + { + expectedException.expect(ResourceLimitExceededException.class); + expectedException.expectMessage("maxSemiJoinRowsInMemory[2] exceeded"); + testQuery( + PLANNER_CONFIG_SEMI_JOIN_ROWS_LIMIT, + "SELECT COUNT(*)\n" + + "FROM druid.foo\n" + + "WHERE SUBSTRING(dim2, 1, 1) IN (\n" + + " SELECT SUBSTRING(dim1, 1, 1) FROM druid.foo WHERE dim1 <> ''\n" + + ")\n", + CalciteTests.REGULAR_USER_AUTH_RESULT, + ImmutableList.of(), + ImmutableList.of() + ); + } + @Test public void testExplainExactCountDistinctOfSemiJoinResult() throws Exception {