Deal with potential cardinality estimate being negative and add logging to hash determine partitions phase (#12443)

* Deal with potential cardinality estimate being negative and add logging

* Fix typo in name

* Refine and minimize logging

* Make it info based on code review

* Create a named constant for the magic number
This commit is contained in:
Agustin Gonzalez 2022-05-20 10:51:06 -07:00 committed by GitHub
parent f9bdb3b236
commit c236227905
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 124 additions and 2 deletions

View File

@ -83,6 +83,7 @@ import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.Consumes;
@ -129,6 +130,16 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
private static final String TASK_PHASE_FAILURE_MSG = "Failed in phase[%s]. See task logs for details.";
// Sometimes Druid estimates one shard for hash partitioning despite conditions
// indicating that there ought to be more than one. We have not been able to
// reproduce but looking at the code around where the following constant is used one
// possibility is that the sketch's estimate is negative. If that case happens
// code has been added to log it and to set the estimate to the value of the
// following constant. It is not necessary to parameterize this value since if this
// happens it is a bug and the new logging may now provide some evidence to reproduce
// and fix
private static final long DEFAULT_NUM_SHARDS_WHEN_ESTIMATE_GOES_NEGATIVE = 7L;
private final ParallelIndexIngestionSpec ingestionSchema;
/**
* Base name for the {@link SubTaskSpec} ID.
@ -703,6 +714,10 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
cardinalityRunner.getReports().values(),
effectiveMaxRowsPerSegment
);
// This is for potential debugging in case we suspect bad estimation of cardinalities etc.
LOG.debug("intervalToNumShards: %s", intervalToNumShards.toString());
} else {
intervalToNumShards = CollectionUtils.mapValues(
mergeCardinalityReports(cardinalityRunner.getReports().values()),
@ -901,13 +916,40 @@ public class ParallelIndexSupervisorTask extends AbstractBatchIndexTask implemen
{
// aggregate all the sub-reports
Map<Interval, Union> finalCollectors = mergeCardinalityReports(reports);
return computeIntervalToNumShards(maxRowsPerSegment, finalCollectors);
}
@Nonnull
@VisibleForTesting
static Map<Interval, Integer> computeIntervalToNumShards(
int maxRowsPerSegment,
Map<Interval, Union> finalCollectors
)
{
return CollectionUtils.mapValues(
finalCollectors,
union -> {
final double estimatedCardinality = union.getEstimate();
// determine numShards based on maxRowsPerSegment and the cardinality
final long estimatedNumShards = Math.round(estimatedCardinality / maxRowsPerSegment);
final long estimatedNumShards;
if (estimatedCardinality <= 0) {
estimatedNumShards = DEFAULT_NUM_SHARDS_WHEN_ESTIMATE_GOES_NEGATIVE;
LOG.warn("Estimated cardinality for union of estimates is zero or less: %.2f, setting num shards to %d",
estimatedCardinality, estimatedNumShards
);
} else {
// determine numShards based on maxRowsPerSegment and the cardinality
estimatedNumShards = Math.round(estimatedCardinality / maxRowsPerSegment);
}
LOG.info("estimatedNumShards %d given estimated cardinality %.2f and maxRowsPerSegment %d",
estimatedNumShards, estimatedCardinality, maxRowsPerSegment
);
// We have seen this before in the wild in situations where more shards should have been created,
// log it if it happens with some information & context
if (estimatedNumShards == 1) {
LOG.info("estimatedNumShards is ONE (%d) given estimated cardinality %.2f and maxRowsPerSegment %d",
estimatedNumShards, estimatedCardinality, maxRowsPerSegment
);
}
try {
return Math.max(Math.toIntExact(estimatedNumShards), 1);
}

View File

@ -23,26 +23,38 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.datasketches.hll.HllSketch;
import org.apache.datasketches.hll.Union;
import org.apache.druid.hll.HyperLogLogCollector;
import org.apache.druid.indexing.common.task.IndexTask;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.testing.junit.LoggerCaptureRule;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.Mockito;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.mockito.Mockito.mock;
public class DimensionCardinalityReportTest
{
private static final ObjectMapper OBJECT_MAPPER = ParallelIndexTestingFactory.createObjectMapper();
private DimensionCardinalityReport target;
@Rule
public final LoggerCaptureRule logger = new LoggerCaptureRule(ParallelIndexSupervisorTask.class);
@Before
public void setup()
{
@ -293,4 +305,72 @@ public class DimensionCardinalityReportTest
intervalToNumShards
);
}
@Test
public void testSupervisorDetermineNegativeNumShardsFromCardinalityReport()
{
  // Start from a clean slate so we only see log events emitted by this test.
  logger.clearLogEvents();

  // Simulate the pathological case: the sketch union reports a negative cardinality estimate.
  Union negativeUnion = mock(Union.class);
  Mockito.when(negativeUnion.getEstimate()).thenReturn(-1.0);
  Interval interval = Intervals.of("2001-01-01/P1D");
  Map<Interval, Union> intervalToUnion = ImmutableMap.of(interval, negativeUnion);

  Map<Interval, Integer> intervalToNumShards =
      ParallelIndexSupervisorTask.computeIntervalToNumShards(10, intervalToUnion);

  // A non-positive estimate must fall back to the default shard count (7).
  // Use Integer.valueOf rather than the deprecated-for-removal Integer(int) constructor.
  Assert.assertEquals(Integer.valueOf(7), intervalToNumShards.get(interval));

  // The fallback path must emit a WARN describing the bad estimate and the chosen shard count.
  List<LogEvent> loggingEvents = logger.getLogEvents();
  String expectedLogMessage =
      "Estimated cardinality for union of estimates is zero or less: -1.00, setting num shards to 7";
  Assert.assertTrue(
      "Logging events: " + loggingEvents,
      loggingEvents.stream()
                   .anyMatch(l ->
                                 l.getLevel().equals(Level.WARN)
                                 && l.getMessage()
                                     .getFormattedMessage()
                                     .equals(expectedLogMessage)
                   )
  );
}
@Test
public void testSupervisorDeterminePositiveNumShardsFromCardinalityReport()
{
  // A healthy positive estimate should follow the normal path:
  // numShards = round(cardinality / maxRowsPerSegment) = round(24 / 6) = 4.
  Union union = mock(Union.class);
  Mockito.when(union.getEstimate()).thenReturn(24.0);
  Interval interval = Intervals.of("2001-01-01/P1D");
  Map<Interval, Union> intervalToUnion = ImmutableMap.of(interval, union);

  Map<Interval, Integer> intervalToNumShards =
      ParallelIndexSupervisorTask.computeIntervalToNumShards(6, intervalToUnion);

  // Use Integer.valueOf rather than the deprecated-for-removal Integer(int) constructor.
  Assert.assertEquals(Integer.valueOf(4), intervalToNumShards.get(interval));
}
@Test
public void testSupervisorDeterminePositiveOneShardFromCardinalityReport()
{
  // Start from a clean slate so we only see log events emitted by this test.
  logger.clearLogEvents();

  // cardinality == maxRowsPerSegment, so exactly one shard is computed:
  // round(24 / 24) = 1.
  Union union = mock(Union.class);
  Mockito.when(union.getEstimate()).thenReturn(24.0);
  Interval interval = Intervals.of("2001-01-01/P1D");
  Map<Interval, Union> intervalToUnion = ImmutableMap.of(interval, union);

  Map<Interval, Integer> intervalToNumShards =
      ParallelIndexSupervisorTask.computeIntervalToNumShards(24, intervalToUnion);

  // Use Integer.valueOf rather than the deprecated-for-removal Integer(int) constructor.
  Assert.assertEquals(Integer.valueOf(1), intervalToNumShards.get(interval));

  // The single-shard case is specifically logged at INFO for field debugging.
  List<LogEvent> loggingEvents = logger.getLogEvents();
  String expectedLogMessage =
      "estimatedNumShards is ONE (1) given estimated cardinality 24.00 and maxRowsPerSegment 24";
  Assert.assertTrue(
      "Logging events: " + loggingEvents,
      loggingEvents.stream()
                   .anyMatch(l ->
                                 l.getLevel().equals(Level.INFO)
                                 && l.getMessage()
                                     .getFormattedMessage()
                                     .equals(expectedLogMessage)
                   )
  );
}
}