diff --git a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
index 134c9da48d7..94c98f8ae2b 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/JoinableFactoryWrapper.java
@@ -25,11 +25,12 @@ import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Sets;
import com.google.inject.Inject;
+import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.filter.Filter;
import org.apache.druid.query.filter.InDimFilter;
import org.apache.druid.segment.filter.FalseFilter;
-import org.apache.druid.segment.filter.Filters;
+import org.apache.druid.utils.CollectionUtils;
import javax.annotation.Nullable;
import java.util.ArrayList;
@@ -120,7 +121,7 @@ public class JoinableFactoryWrapper
}
/**
- * Converts a join clause into an "in" filter if possible.
+ * Converts a join clause into appropriate filter(s) if possible.
*
* The requirements are:
*
@@ -144,49 +145,44 @@ public class JoinableFactoryWrapper
final Set rightPrefixes
)
{
+ // This optimization only kicks in for an INNER join with exactly one equi-condition and no non-equi conditions.
+ final List equiConditions = clause.getCondition().getEquiConditions();
if (clause.getJoinType() == JoinType.INNER
&& clause.getCondition().getNonEquiConditions().isEmpty()
- && clause.getCondition().getEquiConditions().size() > 0) {
- final List filters = new ArrayList<>();
- int numValues = maxNumFilterValues;
+ && equiConditions.size() == 1) {
// if the right side columns are required, the clause cannot be fully converted
boolean joinClauseFullyConverted = requiredColumns.stream().noneMatch(clause::includesColumn);
+ final Equality condition = CollectionUtils.getOnlyElement(
+ equiConditions,
+ xse -> new IAE("Expected only one equi condition")
+ );
- for (final Equality condition : clause.getCondition().getEquiConditions()) {
- final String leftColumn = condition.getLeftExpr().getBindingIfIdentifier();
+ final String leftColumn = condition.getLeftExpr().getBindingIfIdentifier();
- if (leftColumn == null) {
- return new JoinClauseToFilterConversion(null, false);
- }
-
- // don't add a filter on any right side table columns. only filter on left base table is supported as of now.
- if (rightPrefixes.stream().anyMatch(leftColumn::startsWith)) {
- joinClauseFullyConverted = false;
- continue;
- }
-
- Joinable.ColumnValuesWithUniqueFlag columnValuesWithUniqueFlag =
- clause.getJoinable().getNonNullColumnValues(condition.getRightColumn(), numValues);
- // For an empty values set, isAllUnique flag will be true only if the column had no non-null values.
- if (columnValuesWithUniqueFlag.getColumnValues().isEmpty()) {
- if (columnValuesWithUniqueFlag.isAllUnique()) {
- return new JoinClauseToFilterConversion(FalseFilter.instance(), true);
- } else {
- joinClauseFullyConverted = false;
- }
- continue;
- }
-
- numValues -= columnValuesWithUniqueFlag.getColumnValues().size();
- filters.add(Filters.toFilter(new InDimFilter(leftColumn, columnValuesWithUniqueFlag.getColumnValues())));
- if (!columnValuesWithUniqueFlag.isAllUnique()) {
- joinClauseFullyConverted = false;
- }
+ if (leftColumn == null) {
+ return new JoinClauseToFilterConversion(null, false);
}
- return new JoinClauseToFilterConversion(Filters.maybeAnd(filters).orElse(null), joinClauseFullyConverted);
- }
+ // don't add a filter on any right side table columns. only filter on left base table is supported as of now.
+ if (rightPrefixes.stream().anyMatch(leftColumn::startsWith)) {
+ return new JoinClauseToFilterConversion(null, false);
+ }
+ Joinable.ColumnValuesWithUniqueFlag columnValuesWithUniqueFlag =
+ clause.getJoinable().getNonNullColumnValues(condition.getRightColumn(), maxNumFilterValues);
+ // For an empty values set, isAllUnique flag will be true only if the column had no non-null values.
+ if (columnValuesWithUniqueFlag.getColumnValues().isEmpty()) {
+ if (columnValuesWithUniqueFlag.isAllUnique()) {
+ return new JoinClauseToFilterConversion(FalseFilter.instance(), true);
+ }
+ return new JoinClauseToFilterConversion(null, false);
+ }
+ final Filter onlyFilter = new InDimFilter(leftColumn, columnValuesWithUniqueFlag.getColumnValues());
+ if (!columnValuesWithUniqueFlag.isAllUnique()) {
+ joinClauseFullyConverted = false;
+ }
+ return new JoinClauseToFilterConversion(onlyFilter, joinClauseFullyConverted);
+ }
return new JoinClauseToFilterConversion(null, false);
}
diff --git a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
index 27933235f97..cf7ced87436 100644
--- a/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
+++ b/processing/src/main/java/org/apache/druid/segment/join/table/IndexedTableJoinable.java
@@ -103,8 +103,7 @@ public class IndexedTableJoinable implements Joinable
}
try (final IndexedTable.Reader reader = table.columnReader(columnPosition)) {
- // Sorted set to encourage "in" filters that result from this method to do dictionary lookups in order.
- // The hopes are that this will improve locality and therefore improve performance.
+ // Build the values as a SortedSet so that InDimFilter doesn't need to create its own SortedSet from them.
final Set allValues = createValuesSet();
boolean allUnique = true;
diff --git a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
index 1ff3664aba4..8ff6d8461f8 100644
--- a/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
+++ b/processing/src/test/java/org/apache/druid/segment/join/JoinableFactoryWrapperTest.java
@@ -84,6 +84,22 @@ public class JoinableFactoryWrapperTest extends NullHandlingTest
RowSignature.builder().add("country", ColumnType.STRING).build()
);
+ private static final InlineDataSource INDEXED_TABLE_DS_THREE_COLS = InlineDataSource.fromIterable( // inline rows of (country, m1, m2)
+ ImmutableList.of(
+ new Object[]{"El Salvador", 1, 1.0},
+ new Object[]{"Mexico", 2, 2.0},
+ new Object[]{"United States", 3, 3.0},
+ new Object[]{"Norway", 4, 4.0},
+ new Object[]{"India", 5, 5.0},
+ new Object[]{"United States", 6, 3.0} // repeats the country and m2 value of the third row, with a different m1
+ ),
+ RowSignature.builder()
+ .add("country", ColumnType.STRING)
+ .add("m1", ColumnType.LONG)
+ .add("m2", ColumnType.DOUBLE)
+ .build()
+ );
+
private static final InlineDataSource NULL_INDEXED_TABLE_DS = InlineDataSource.fromIterable(
ImmutableList.of(
new Object[]{null}
@@ -99,6 +115,14 @@ public class JoinableFactoryWrapperTest extends NullHandlingTest
DateTimes.nowUtc().toString()
);
+ private static final IndexedTable TEST_INDEXED_TABLE_THREE_COLS = new RowBasedIndexedTable<>( // indexed table built from INDEXED_TABLE_DS_THREE_COLS
+ INDEXED_TABLE_DS_THREE_COLS.getRowsAsList(),
+ INDEXED_TABLE_DS_THREE_COLS.rowAdapter(),
+ INDEXED_TABLE_DS_THREE_COLS.getRowSignature(),
+ ImmutableSet.of("country", "m1", "m2"), // NOTE(review): presumably the key columns; confirm against RowBasedIndexedTable
+ DateTimes.nowUtc().toString()
+ );
+
private static final IndexedTable TEST_NULL_INDEXED_TABLE = new RowBasedIndexedTable<>(
NULL_INDEXED_TABLE_DS.getRowsAsList(),
NULL_INDEXED_TABLE_DS.rowAdapter(),
@@ -625,4 +649,29 @@ public class JoinableFactoryWrapperTest extends NullHandlingTest
conversion
);
}
+
+ @Test
+ public void test_convertJoinsToPartialFiltersMultipleCondtions()
+ {
+ JoinableClause joinableClause = new JoinableClause(
+ "j.",
+ new IndexedTableJoinable(TEST_INDEXED_TABLE_THREE_COLS),
+ JoinType.INNER,
+ JoinConditionAnalysis.forExpression("x == \"j.country\" && y == \"j.m1\"", "j.", ExprMacroTable.nil())
+ );
+ final Pair, List> conversion = JoinableFactoryWrapper.convertJoinsToFilters(
+ ImmutableList.of(joinableClause),
+ ImmutableSet.of("x", "y"),
+ Integer.MAX_VALUE
+ );
+
+ // The clause has two equi-conditions, so the optimization must not kick in: the first list is empty and the clause is returned unchanged in the second.
+ Assert.assertEquals(
+ Pair.of(
+ ImmutableList.of(),
+ ImmutableList.of(joinableClause)
+ ),
+ conversion
+ );
+ }
}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
index da5abc9582a..1f45aebc8b5 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteJoinQueryTest.java
@@ -5530,4 +5530,98 @@ public class CalciteJoinQueryTest extends BaseCalciteQueryTest
return results;
}
}
+
+ @Test
+ public void testJoinsWithTwoConditions()
+ {
+ Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+ testQuery(
+ "SELECT t1.__time, t1.m1\n"
+ + "FROM foo t1\n"
+ + "JOIN (SELECT m1, MAX(__time) as latest_time FROM foo WHERE m1 IN (1,2) GROUP BY m1) t2\n"
+ + "ON t1.m1 = t2.m1 AND t1.__time = t2.latest_time\n",
+ context,
+ ImmutableList.of( // expected native query: a scan over an INNER join that keeps both equality conditions
+ newScanQueryBuilder()
+ .dataSource(
+ join(
+ new TableDataSource(CalciteTests.DATASOURCE1),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+ .setDimFilter(in("m1", ImmutableList.of("1", "2"), null))
+ .setDimensions(new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT))
+ .setAggregatorSpecs(aggregators(new LongMaxAggregatorFactory("a0", "__time")))
+ .setContext(context)
+ .build()
+ ),
+ "j0.",
+ "((\"m1\" == \"j0.d0\") && (\"__time\" == \"j0.a0\"))",
+ JoinType.INNER
+ )
+ )
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .columns("__time", "m1")
+ .context(context)
+ .build()
+ ),
+ ImmutableList.of( // expected result rows as (__time, m1)
+ new Object[]{946684800000L, 1.0f},
+ new Object[]{946771200000L, 2.0f}
+ )
+ );
+ }
+
+ @Test
+ public void testJoinsWithThreeConditions()
+ {
+ Map context = new HashMap<>(QUERY_CONTEXT_DEFAULT);
+ testQuery(
+ "SELECT t1.__time, t1.m1, t1.m2\n"
+ + "FROM foo t1\n"
+ + "JOIN (SELECT m1, m2, MAX(__time) as latest_time FROM foo WHERE m1 IN (1,2) AND m2 IN (1,2) GROUP by m1,m2) t2\n"
+ + "ON t1.m1 = t2.m1 AND t1.m2 = t2.m2 AND t1.__time = t2.latest_time\n",
+ context,
+ ImmutableList.of( // expected native query: a scan over an INNER join that keeps all three equality conditions
+ newScanQueryBuilder()
+ .dataSource(
+ join(
+ new TableDataSource(CalciteTests.DATASOURCE1),
+ new QueryDataSource(
+ GroupByQuery.builder()
+ .setInterval(querySegmentSpec(Filtration.eternity()))
+ .setGranularity(Granularities.ALL)
+ .setDataSource(new TableDataSource(CalciteTests.DATASOURCE1))
+ .setDimFilter(
+ and(
+ in("m1", ImmutableList.of("1", "2"), null),
+ in("m2", ImmutableList.of("1", "2"), null)
+ )
+ )
+ .setDimensions(
+ new DefaultDimensionSpec("m1", "d0", ColumnType.FLOAT),
+ new DefaultDimensionSpec("m2", "d1", ColumnType.DOUBLE)
+ )
+ .setAggregatorSpecs(aggregators(new LongMaxAggregatorFactory("a0", "__time")))
+ .setContext(context)
+ .build()
+ ),
+ "j0.",
+ "((\"m1\" == \"j0.d0\") && (\"m2\" == \"j0.d1\") && (\"__time\" == \"j0.a0\"))",
+ JoinType.INNER
+ )
+ )
+ .intervals(querySegmentSpec(Filtration.eternity()))
+ .columns("__time", "m1", "m2")
+ .context(context)
+ .build()
+ ),
+ ImmutableList.of( // expected result rows as (__time, m1, m2)
+ new Object[]{946684800000L, 1.0f, 1.0},
+ new Object[]{946771200000L, 2.0f, 2.0}
+ )
+ );
+ }
}