MSQ: Nicer error when sortMerge join falls back to broadcast. (#16002)

* MSQ: Nicer error when sortMerge join falls back to broadcast.

In certain cases, joins run as broadcast even when the user hinted
that they wanted sortMerge. This happens when the sortMerge algorithm
is unable to process the join because the join condition is not a direct
comparison between two fields on the LHS and RHS.
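
For instance, here is a sketch of how a user ends up in this situation. The
context key and algorithm id match the constants used in the diff below; the
example join condition is hypothetical:

    import com.google.common.collect.ImmutableMap;
    import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
    import org.apache.druid.sql.calcite.planner.PlannerContext;

    import java.util.Map;

    // The user pins the join algorithm in the query context...
    final Map<String, Object> context = ImmutableMap.of(
        PlannerContext.CTX_SQL_JOIN_ALGORITHM,  // "sqlJoinAlgorithm"
        JoinAlgorithm.SORT_MERGE.getId()        // "sortMerge"
    );
    // ...but a condition like ON LOWER(t1.k) = t2.k is not a direct
    // field-to-field comparison, so the planner quietly falls back to a
    // broadcast join. If the broadcast tables are then too large,
    // BroadcastTablesTooLargeFault is raised.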

When this happens, the error message from BroadcastTablesTooLargeFault
is confusing: it suggests switching to sortMerge as a fix, even though
the user may have already configured sortMerge.

This patch fixes the problem by providing two error messages, depending on
whether broadcast join was used as the primary selection or as a fallback selection.
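
For concreteness, substituting a 100,000-byte limit into the two format
strings in the diff below yields roughly:

    Primary selection (broadcast chosen, or no hint given):
      Size of broadcast tables in JOIN exceeds reserved memory limit
      (memory reserved for broadcast tables = [100,000] bytes). Increase
      available memory, or set [sqlJoinAlgorithm: sortMerge] in query context
      to use a shuffle-based join.

    Fallback selection (sortMerge hinted, but broadcast used anyway):
      Size of broadcast tables in JOIN exceeds reserved memory limit
      (memory reserved for broadcast tables = [100,000] bytes). Try increasing
      available memory. This query is using broadcast JOIN even though
      [sqlJoinAlgorithm: sortMerge] is set in query context, because the
      configured join algorithm does not support the join condition.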

* Style.

* Better message.
Gian Merlino, 2024-03-01 13:16:39 -08:00 (committed by GitHub)
parent ef48aceff8, commit 8d3ed31015
4 changed files with 122 additions and 15 deletions

File: BroadcastTablesTooLargeFault.java

@@ -20,11 +20,14 @@
 package org.apache.druid.msq.indexing.error;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
+import org.apache.druid.java.util.common.StringUtils;
 import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
 import org.apache.druid.sql.calcite.planner.PlannerContext;
 
+import javax.annotation.Nullable;
 import java.util.Objects;
 
 @JsonTypeName(BroadcastTablesTooLargeFault.CODE)
@@ -34,19 +37,18 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
   private final long maxBroadcastTablesSize;
+  @Nullable
+  private final JoinAlgorithm configuredJoinAlgorithm;
 
   @JsonCreator
-  public BroadcastTablesTooLargeFault(@JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize)
+  public BroadcastTablesTooLargeFault(
+      @JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize,
+      @Nullable @JsonProperty("configuredJoinAlgorithm") final JoinAlgorithm configuredJoinAlgorithm
+  )
   {
-    super(
-        CODE,
-        "Size of broadcast tables in JOIN exceeds reserved memory limit "
-        + "(memory reserved for broadcast tables = %d bytes). "
-        + "Increase available memory, or set %s: %s in query context to use a shuffle-based join.",
-        maxBroadcastTablesSize,
-        PlannerContext.CTX_SQL_JOIN_ALGORITHM,
-        JoinAlgorithm.SORT_MERGE.toString()
-    );
+    super(CODE, makeMessage(maxBroadcastTablesSize, configuredJoinAlgorithm));
     this.maxBroadcastTablesSize = maxBroadcastTablesSize;
+    this.configuredJoinAlgorithm = configuredJoinAlgorithm;
   }
 
   @JsonProperty
@@ -55,6 +57,14 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
     return maxBroadcastTablesSize;
   }
 
+  @Nullable
+  @JsonProperty
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  public JoinAlgorithm getConfiguredJoinAlgorithm()
+  {
+    return configuredJoinAlgorithm;
+  }
+
   @Override
   public boolean equals(Object o)
   {
@@ -68,12 +78,38 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
       return false;
     }
     BroadcastTablesTooLargeFault that = (BroadcastTablesTooLargeFault) o;
-    return maxBroadcastTablesSize == that.maxBroadcastTablesSize;
+    return maxBroadcastTablesSize == that.maxBroadcastTablesSize
+           && configuredJoinAlgorithm == that.configuredJoinAlgorithm;
   }
 
   @Override
   public int hashCode()
   {
-    return Objects.hash(super.hashCode(), maxBroadcastTablesSize);
+    return Objects.hash(super.hashCode(), maxBroadcastTablesSize, configuredJoinAlgorithm);
+  }
+
+  private static String makeMessage(final long maxBroadcastTablesSize, final JoinAlgorithm configuredJoinAlgorithm)
+  {
+    if (configuredJoinAlgorithm == null || configuredJoinAlgorithm == JoinAlgorithm.BROADCAST) {
+      return StringUtils.format(
+          "Size of broadcast tables in JOIN exceeds reserved memory limit "
+          + "(memory reserved for broadcast tables = [%,d] bytes). "
+          + "Increase available memory, or set [%s: %s] in query context to use a shuffle-based join.",
+          maxBroadcastTablesSize,
+          PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+          JoinAlgorithm.SORT_MERGE.toString()
+      );
+    } else {
+      return StringUtils.format(
+          "Size of broadcast tables in JOIN exceeds reserved memory limit "
+          + "(memory reserved for broadcast tables = [%,d] bytes). "
+          + "Try increasing available memory. "
+          + "This query is using broadcast JOIN even though [%s: %s] is set in query context, because the configured "
+          + "join algorithm does not support the join condition.",
+          maxBroadcastTablesSize,
+          PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+          configuredJoinAlgorithm.toString()
+      );
+    }
   }
 }
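
As a usage sketch of the new serde behavior (assuming the MSQ fault Jackson
modules are registered, as MSQFaultSerdeTest below does), the NON_NULL
annotation keeps the wire format backward compatible:

    import com.fasterxml.jackson.databind.ObjectMapper;
    import org.apache.druid.msq.guice.MSQIndexingModule;
    import org.apache.druid.msq.indexing.error.BroadcastTablesTooLargeFault;
    import org.apache.druid.segment.TestHelper;
    import org.apache.druid.sql.calcite.planner.JoinAlgorithm;

    // A minimal sketch, assuming the module registration used by the serde test.
    final ObjectMapper mapper = TestHelper.makeJsonMapper()
        .registerModules(new MSQIndexingModule().getJacksonModules());

    // With a null configuredJoinAlgorithm, @JsonInclude(NON_NULL) omits the
    // field, so the JSON looks exactly as it did before this patch.
    mapper.writeValueAsString(new BroadcastTablesTooLargeFault(10, null));

    // With sortMerge configured, the field appears, serialized by its id
    // (assuming JoinAlgorithm's usual JSON representation).
    mapper.writeValueAsString(new BroadcastTablesTooLargeFault(10, JoinAlgorithm.SORT_MERGE));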

File: BroadcastJoinSegmentMapFnProcessor.java

@@ -42,11 +42,14 @@ import org.apache.druid.query.Query;
 import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.Cursor;
 import org.apache.druid.segment.SegmentReference;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -242,7 +245,15 @@ public class BroadcastJoinSegmentMapFnProcessor implements FrameProcessor<Functi
       memoryUsed += frame.numBytes();
 
       if (memoryUsed > memoryReservedForBroadcastJoin) {
-        throw new MSQException(new BroadcastTablesTooLargeFault(memoryReservedForBroadcastJoin));
+        throw new MSQException(
+            new BroadcastTablesTooLargeFault(
+                memoryReservedForBroadcastJoin,
+                Optional.ofNullable(query)
+                        .map(q -> q.context().getString(PlannerContext.CTX_SQL_JOIN_ALGORITHM))
+                        .map(JoinAlgorithm::fromString)
+                        .orElse(null)
+            )
+        );
       }
 
       addFrame(channelNumber, frame);
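
The Optional chain above degrades gracefully when no hint is present. A
minimal sketch with assumed inputs:

    import com.google.common.collect.ImmutableMap;
    import org.apache.druid.query.QueryContext;
    import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
    import org.apache.druid.sql.calcite.planner.PlannerContext;

    import java.util.Optional;

    // When the user set the hint, the chain yields SORT_MERGE and the fault
    // picks the fallback-specific message.
    final QueryContext ctx = QueryContext.of(
        ImmutableMap.of(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.getId())
    );
    final JoinAlgorithm hinted = Optional
        .ofNullable(ctx.getString(PlannerContext.CTX_SQL_JOIN_ALGORITHM))
        .map(JoinAlgorithm::fromString)
        .orElse(null);  // null when the key is absent (or the query itself is
                        // null), which selects the original broadcast-oriented
                        // message.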

File: MSQFaultSerdeTest.java

@@ -26,6 +26,7 @@ import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.msq.guice.MSQIndexingModule;
 import org.apache.druid.segment.TestHelper;
 import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -49,7 +50,8 @@ public class MSQFaultSerdeTest
   @Test
   public void testFaultSerde() throws IOException
   {
-    assertFaultSerde(new BroadcastTablesTooLargeFault(10));
+    assertFaultSerde(new BroadcastTablesTooLargeFault(10, null));
+    assertFaultSerde(new BroadcastTablesTooLargeFault(10, JoinAlgorithm.SORT_MERGE));
     assertFaultSerde(CanceledFault.INSTANCE);
     assertFaultSerde(new CannotParseExternalDataFault("the message"));
     assertFaultSerde(new ColumnTypeNotSupportedFault("the column", null));

File: BroadcastJoinSegmentMapFnProcessorTest.java

@@ -19,6 +19,7 @@
 package org.apache.druid.msq.querykit;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.util.concurrent.ListenableFuture;
 import it.unimi.dsi.fastutil.ints.Int2IntMap;
@@ -42,12 +43,17 @@ import org.apache.druid.msq.indexing.error.MSQException;
 import org.apache.druid.query.DataSource;
 import org.apache.druid.query.InlineDataSource;
 import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryContext;
 import org.apache.druid.segment.QueryableIndexStorageAdapter;
 import org.apache.druid.segment.StorageAdapter;
 import org.apache.druid.segment.TestIndex;
 import org.apache.druid.segment.join.JoinConditionAnalysis;
 import org.apache.druid.segment.join.JoinType;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
 import org.apache.druid.testing.InitializedNullHandlingTest;
+import org.easymock.EasyMock;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.junit.Assert;
@@ -232,7 +238,59 @@ public class BroadcastJoinSegmentMapFnProcessorTest extends InitializedNullHandl
         }
     );
 
-    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000), e.getFault());
+    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000, null), e.getFault());
+  }
+
+  /**
+   * Like {@link #testBuildTableMemoryLimit()}, but with {@link JoinAlgorithm#SORT_MERGE} configured, so we can
+   * verify we get a better error message.
+   */
+  @Test
+  public void testBuildTableMemoryLimitWithSortMergeConfigured() throws IOException
+  {
+    final Int2IntMap sideStageChannelNumberMap = new Int2IntOpenHashMap();
+    sideStageChannelNumberMap.put(0, 0);
+
+    final List<ReadableFrameChannel> channels = new ArrayList<>();
+    channels.add(new ReadableFileFrameChannel(FrameFile.open(testDataFile1, ByteTracker.unboundedTracker())));
+
+    final List<FrameReader> channelReaders = new ArrayList<>();
+    channelReaders.add(frameReader1);
+
+    // Query: used only to retrieve configured join from context
+    final Query<?> mockQuery = EasyMock.mock(Query.class);
+    EasyMock.expect(mockQuery.context()).andReturn(
+        QueryContext.of(
+            ImmutableMap.of(
+                PlannerContext.CTX_SQL_JOIN_ALGORITHM,
+                JoinAlgorithm.SORT_MERGE.getId()
+            )
+        )
+    );
+    EasyMock.replay(mockQuery);
+
+    final BroadcastJoinSegmentMapFnProcessor broadcastJoinHelper = new BroadcastJoinSegmentMapFnProcessor(
+        mockQuery,
+        sideStageChannelNumberMap,
+        channels,
+        channelReaders,
+        100_000 // Low memory limit; we will hit this
+    );
+
+    Assert.assertEquals(ImmutableSet.of(0), broadcastJoinHelper.getSideChannelNumbers());
+
+    final MSQException e = Assert.assertThrows(
+        MSQException.class,
+        () -> {
+          boolean doneReading = false;
+          while (!doneReading) {
+            final IntSet readableInputs = new IntOpenHashSet(new int[]{0});
+            doneReading = broadcastJoinHelper.buildBroadcastTablesIncrementally(readableInputs);
+          }
+        }
+    );
+
+    Assert.assertEquals(new BroadcastTablesTooLargeFault(100_000, JoinAlgorithm.SORT_MERGE), e.getFault());
+    EasyMock.verify(mockQuery);
   }
 
   /**