fix incorrect check of maxSemiJoinRowsInMemory (#6242)

This commit is contained in:
Dayue Gao 2018-08-28 07:28:29 +08:00 committed by Fangjin Yang
parent 4a8b09b6a9
commit 2325844a38
2 changed files with 32 additions and 4 deletions

View File

@ -288,6 +288,8 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
new ArrayList<>(), new ArrayList<>(),
new Accumulator<List<RexNode>, Object[]>() new Accumulator<List<RexNode>, Object[]>()
{ {
int numRows;
@Override @Override
public List<RexNode> accumulate(final List<RexNode> theConditions, final Object[] row) public List<RexNode> accumulate(final List<RexNode> theConditions, final Object[] row)
{ {
@ -301,14 +303,14 @@ public class DruidSemiJoin extends DruidRel<DruidSemiJoin>
} }
final String stringValue = DimensionHandlerUtils.convertObjectToString(value); final String stringValue = DimensionHandlerUtils.convertObjectToString(value);
values.add(stringValue); values.add(stringValue);
if (values.size() > maxSemiJoinRowsInMemory) { }
if (valuess.add(values)) {
if (++numRows > maxSemiJoinRowsInMemory) {
throw new ResourceLimitExceededException( throw new ResourceLimitExceededException(
StringUtils.format("maxSemiJoinRowsInMemory[%,d] exceeded", maxSemiJoinRowsInMemory) StringUtils.format("maxSemiJoinRowsInMemory[%,d] exceeded", maxSemiJoinRowsInMemory)
); );
} }
}
if (valuess.add(values)) {
final List<RexNode> subConditions = new ArrayList<>(); final List<RexNode> subConditions = new ArrayList<>();
for (int i = 0; i < values.size(); i++) { for (int i = 0; i < values.size(); i++) {

View File

@ -39,6 +39,7 @@ import io.druid.query.Query;
import io.druid.query.QueryContexts; import io.druid.query.QueryContexts;
import io.druid.query.QueryDataSource; import io.druid.query.QueryDataSource;
import io.druid.query.QueryRunnerFactoryConglomerate; import io.druid.query.QueryRunnerFactoryConglomerate;
import io.druid.query.ResourceLimitExceededException;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleMaxAggregatorFactory; import io.druid.query.aggregation.DoubleMaxAggregatorFactory;
@ -187,6 +188,13 @@ public class CalciteQueryTest extends CalciteTestBase
return DateTimes.inferTzfromString("America/Los_Angeles"); return DateTimes.inferTzfromString("America/Los_Angeles");
} }
}; };
private static final PlannerConfig PLANNER_CONFIG_SEMI_JOIN_ROWS_LIMIT = new PlannerConfig() {
@Override
public int getMaxSemiJoinRowsInMemory()
{
return 2;
}
};
private static final String LOS_ANGELES = "America/Los_Angeles"; private static final String LOS_ANGELES = "America/Los_Angeles";
@ -4696,6 +4704,24 @@ public class CalciteQueryTest extends CalciteTestBase
); );
} }
@Test
public void testMaxSemiJoinRowsInMemory() throws Exception
{
expectedException.expect(ResourceLimitExceededException.class);
expectedException.expectMessage("maxSemiJoinRowsInMemory[2] exceeded");
testQuery(
PLANNER_CONFIG_SEMI_JOIN_ROWS_LIMIT,
"SELECT COUNT(*)\n"
+ "FROM druid.foo\n"
+ "WHERE SUBSTRING(dim2, 1, 1) IN (\n"
+ " SELECT SUBSTRING(dim1, 1, 1) FROM druid.foo WHERE dim1 <> ''\n"
+ ")\n",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(),
ImmutableList.of()
);
}
@Test @Test
public void testExplainExactCountDistinctOfSemiJoinResult() throws Exception public void testExplainExactCountDistinctOfSemiJoinResult() throws Exception
{ {