IncrementalIndex#add is no longer thread-safe. (#15697)

* IncrementalIndex#add is no longer thread-safe.

Following #14866, there is no longer a reason for IncrementalIndex#add
to be thread-safe.
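For illustration, a minimal caller-side sketch of what the new contract implies (hypothetical code, not part of this patch): with "add" no longer thread-safe, a caller that previously added rows from several threads must serialize those adds itself, for example by funneling them through a single-threaded executor, while query-side reads can still run concurrently.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Hypothetical wrapper showing one way to satisfy the new single-writer contract.
 * "Index" stands in for IncrementalIndex; only the serialization pattern matters here.
 */
class SerializedAdds
{
  interface Index
  {
    void add(Object row);
  }

  private final Index index;
  private final ExecutorService addExecutor = Executors.newSingleThreadExecutor();

  SerializedAdds(Index index)
  {
    this.index = index;
  }

  /** All adds funnel through one thread, so the index never sees concurrent add() calls. */
  Future<?> addAsync(Object row)
  {
    return addExecutor.submit(() -> index.add(row));
  }
}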

It turns out the method was already not using its selectors in a thread-safe way,
as exposed by #15615, which made `testMultithreadAddFactsUsingExpressionAndJavaScript`
in `IncrementalIndexIngestionTest` flaky. Note that this problem isn't
new: Strings have been stored in the dimension selectors for some time,
but we didn't have a test that checked that case; we only had this test,
which covers concurrent adds involving numeric selectors.

At any rate, this patch changes OnheapIncrementalIndex to no longer try
to offer a thread-safe "add" method. It also improves performance a bit
by adding a row ID supplier to the selectors it uses to read InputRows,
so those selectors can cache values rather than recomputing them on every read.
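A minimal sketch of that row-ID-based caching idea (hypothetical names, not the actual Druid selector classes): as long as the row ID reported by the supplier is unchanged, the selector can return a previously computed value instead of re-reading and re-parsing the current InputRow.

import java.util.function.LongSupplier;
import java.util.function.Supplier;

/**
 * Hypothetical caching selector: safe only because add() is now single-threaded,
 * so the cached fields are never written by two threads at once.
 */
class CachingLongSelectorSketch
{
  private final Supplier<Object> rawValue; // reads the current field from the input row
  private final LongSupplier rowId;        // incremented each time a new row is set

  private long cachedRowId = -1;
  private long cachedValue;

  CachingLongSelectorSketch(Supplier<Object> rawValue, LongSupplier rowId)
  {
    this.rawValue = rawValue;
    this.rowId = rowId;
  }

  long getLong()
  {
    final long currentRowId = rowId.getAsLong();
    if (currentRowId != cachedRowId) {
      // New row: parse once; later reads for the same row reuse the cached value.
      cachedValue = Long.parseLong(String.valueOf(rawValue.get()));
      cachedRowId = currentRowId;
    }
    return cachedValue;
  }
}

Several aggregators reading the same column for the same row then pay the parsing cost only once.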

This patch also:

1) Adds synchronization to HyperUniquesAggregator and CardinalityAggregator,
   which the similar datasketches versions already have. This is done to
   help them adhere to the contract of Aggregator: concurrent calls to
   "aggregate" and "get" must be thread-safe (see the sketch after this list).

2) Updates OnHeapIncrementalIndexBenchmark to use JMH and moves it to the
   druid-benchmarks module.
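To illustrate that contract, a toy sketch (hypothetical class, not the actual Druid aggregators): "aggregate" runs on the ingest thread while "get" may be called concurrently from query threads, so both methods are synchronized and "get" never observes a partially applied update.

import java.util.function.LongSupplier;

class SynchronizedSumAggregatorSketch
{
  private final LongSupplier selector; // reads the current input value
  private long sum;

  SynchronizedSumAggregatorSketch(LongSupplier selector)
  {
    this.selector = selector;
  }

  /** Called once per added row, from the ingest thread. */
  public synchronized void aggregate()
  {
    sum += selector.getAsLong();
  }

  /** May be called concurrently from query threads. */
  public synchronized Object get()
  {
    return sum;
  }
}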

* Spelling.

* Changes from static analysis.

* Fix javadoc.
Gian Merlino 2024-01-18 03:45:22 -08:00 committed by GitHub
parent 764f41d959
commit 792e5c58e4
10 changed files with 502 additions and 751 deletions

View File

@ -0,0 +1,335 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.benchmark.indexing;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.FinalizeResultsQueryRunner;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory;
import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.Segment;
import org.apache.druid.segment.incremental.IncrementalIndex;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.incremental.IndexSizeExceededException;
import org.apache.druid.segment.incremental.OnheapIncrementalIndex;
import org.joda.time.Interval;
import org.junit.Assert;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Benchmark for {@link OnheapIncrementalIndex} doing queries and adds simultaneously.
*/
@State(Scope.Benchmark)
@Fork(value = 1)
@Warmup(iterations = 3)
@Measurement(iterations = 5)
public class OnheapIncrementalIndexBenchmark
{
static final int DIMENSION_COUNT = 5;
static {
NullHandling.initializeForTests();
}
/**
* Number of index and query tasks.
*/
private final int taskCount = 30;
/**
* Number of elements to add for each index task.
*/
private final int elementsPerAddTask = 1 << 15;
/**
* Number of query tasks to run simultaneously.
*/
private final int queryThreads = 4;
private AggregatorFactory[] factories;
private IncrementalIndex incrementalIndex;
private ListeningExecutorService indexExecutor;
private ListeningExecutorService queryExecutor;
private static MapBasedInputRow getLongRow(long timestamp, int rowID, int dimensionCount)
{
List<String> dimensionList = new ArrayList<String>(dimensionCount);
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
for (int i = 0; i < dimensionCount; i++) {
String dimName = StringUtils.format("Dim_%d", i);
dimensionList.add(dimName);
builder.put(dimName, Integer.valueOf(rowID).longValue());
}
return new MapBasedInputRow(timestamp, dimensionList, builder.build());
}
@Setup(Level.Trial)
public void setupFactories()
{
final ArrayList<AggregatorFactory> ingestAggregatorFactories = new ArrayList<>(DIMENSION_COUNT + 1);
ingestAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < DIMENSION_COUNT; ++i) {
ingestAggregatorFactories.add(
new LongSumAggregatorFactory(
StringUtils.format("sumResult%s", i),
StringUtils.format("Dim_%s", i)
)
);
ingestAggregatorFactories.add(
new DoubleSumAggregatorFactory(
StringUtils.format("doubleSumResult%s", i),
StringUtils.format("Dim_%s", i)
)
);
}
factories = ingestAggregatorFactories.toArray(new AggregatorFactory[0]);
}
@Setup(Level.Trial)
public void setupExecutors()
{
indexExecutor = MoreExecutors.listeningDecorator(
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("index-executor-%d")
.setPriority(Thread.MIN_PRIORITY)
.build()
)
);
queryExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
queryThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("query-executor-%d")
.build()
)
);
}
@Setup(Level.Invocation)
public void setupIndex()
throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException
{
final Constructor<? extends OnheapIncrementalIndex> constructor =
OnheapIncrementalIndex.class.getDeclaredConstructor(
IncrementalIndexSchema.class,
int.class,
long.class,
boolean.class,
boolean.class
);
constructor.setAccessible(true);
this.incrementalIndex =
constructor.newInstance(
new IncrementalIndexSchema.Builder().withMetrics(factories).build(),
elementsPerAddTask * taskCount,
1_000_000_000L,
false,
false
);
}
@TearDown(Level.Invocation)
public void tearDownIndex()
{
incrementalIndex.close();
incrementalIndex = null;
}
@TearDown(Level.Trial)
public void tearDownExecutors() throws InterruptedException
{
indexExecutor.shutdown();
queryExecutor.shutdown();
if (!indexExecutor.awaitTermination(1, TimeUnit.MINUTES)) {
throw new ISE("Could not shut down indexExecutor");
}
if (!queryExecutor.awaitTermination(1, TimeUnit.MINUTES)) {
throw new ISE("Could not shut down queryExecutor");
}
indexExecutor = null;
queryExecutor = null;
}
@Benchmark
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
public void concurrentAddRead() throws InterruptedException, ExecutionException
{
final ArrayList<AggregatorFactory> queryAggregatorFactories = new ArrayList<>(DIMENSION_COUNT + 1);
queryAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < DIMENSION_COUNT; ++i) {
queryAggregatorFactories.add(
new LongSumAggregatorFactory(
StringUtils.format("sumResult%s", i),
StringUtils.format("sumResult%s", i)
)
);
queryAggregatorFactories.add(
new DoubleSumAggregatorFactory(
StringUtils.format("doubleSumResult%s", i),
StringUtils.format("doubleSumResult%s", i)
)
);
}
final long timestamp = System.currentTimeMillis();
final Interval queryInterval = Intervals.of("1900-01-01T00:00:00Z/2900-01-01T00:00:00Z");
final List<ListenableFuture<?>> indexFutures = new ArrayList<>();
final List<ListenableFuture<?>> queryFutures = new ArrayList<>();
final Segment incrementalIndexSegment = new IncrementalIndexSegment(incrementalIndex, null);
final QueryRunnerFactory factory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
);
final AtomicInteger currentlyRunning = new AtomicInteger(0);
final AtomicBoolean concurrentlyRan = new AtomicBoolean(false);
final AtomicBoolean someoneRan = new AtomicBoolean(false);
for (int j = 0; j < taskCount; j++) {
indexFutures.add(
indexExecutor.submit(
() -> {
currentlyRunning.incrementAndGet();
try {
for (int i = 0; i < elementsPerAddTask; i++) {
incrementalIndex.add(getLongRow(timestamp + i, 1, DIMENSION_COUNT));
}
}
catch (IndexSizeExceededException e) {
throw new RuntimeException(e);
}
currentlyRunning.decrementAndGet();
someoneRan.set(true);
}
)
);
queryFutures.add(
queryExecutor.submit(
() -> {
QueryRunner<Result<TimeseriesResultValue>> runner =
new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(Granularities.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
List<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
for (Result<TimeseriesResultValue> result : results) {
if (someoneRan.get()) {
Assert.assertTrue(result.getValue().getDoubleMetric("doubleSumResult0") > 0);
}
}
if (currentlyRunning.get() > 0) {
concurrentlyRan.set(true);
}
}
)
);
}
List<ListenableFuture<?>> allFutures = new ArrayList<>(queryFutures.size() + indexFutures.size());
allFutures.addAll(queryFutures);
allFutures.addAll(indexFutures);
Futures.allAsList(allFutures).get();
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(Granularities.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
List<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
final int expectedVal = elementsPerAddTask * taskCount;
for (Result<TimeseriesResultValue> result : results) {
Assert.assertEquals(elementsPerAddTask, result.getValue().getLongMetric("rows").intValue());
for (int i = 0; i < DIMENSION_COUNT; ++i) {
Assert.assertEquals(
StringUtils.format("Failed long sum on dimension %d", i),
expectedVal,
result.getValue().getLongMetric(StringUtils.format("sumResult%s", i)).intValue()
);
Assert.assertEquals(
StringUtils.format("Failed double sum on dimension %d", i),
expectedVal,
result.getValue().getDoubleMetric(StringUtils.format("doubleSumResult%s", i)).intValue()
);
}
}
}
}

View File

@ -293,7 +293,7 @@ public class InputRowSerde
public static SerializeResult toBytes(
final Map<String, IndexSerdeTypeHelper> typeHelperMap,
final InputRow row,
AggregatorFactory[] aggs
final AggregatorFactory[] aggs
)
{
try {
@ -323,14 +323,15 @@ public class InputRowSerde
}
//writing all metrics
Supplier<InputRow> supplier = () -> row;
WritableUtils.writeVInt(out, aggs.length);
for (AggregatorFactory aggFactory : aggs) {
String k = aggFactory.getName();
writeString(k, out);
final IncrementalIndex.InputRowHolder holder = new IncrementalIndex.InputRowHolder();
holder.set(row);
try (Aggregator agg = aggFactory.factorize(
IncrementalIndex.makeColumnSelectorFactory(VirtualColumns.EMPTY, aggFactory, supplier)
IncrementalIndex.makeColumnSelectorFactory(VirtualColumns.EMPTY, holder, aggFactory)
)) {
try {
agg.aggregate();

View File

@ -83,7 +83,7 @@ public class CardinalityAggregator implements Aggregator
}
@Override
public void aggregate()
public synchronized void aggregate()
{
if (byRow) {
hashRow(selectorPluses, collector);
@ -93,10 +93,10 @@ public class CardinalityAggregator implements Aggregator
}
@Override
public Object get()
public synchronized Object get()
{
// Workaround for non-thread-safe use of HyperLogLogCollector.
// OnheapIncrementalIndex has a penchant for calling "aggregate" and "get" simultaneously.
// Must make a new collector duplicating the underlying buffer to ensure the object from "get" is usable
// in a thread-safe manner.
return HyperLogLogCollector.makeCollectorSharingStorage(collector);
}

View File

@ -39,7 +39,7 @@ public class HyperUniquesAggregator implements Aggregator
}
@Override
public void aggregate()
public synchronized void aggregate()
{
Object object = selector.getObject();
if (object == null) {
@ -53,13 +53,13 @@ public class HyperUniquesAggregator implements Aggregator
@Nullable
@Override
public Object get()
public synchronized Object get()
{
if (collector == null) {
return null;
}
// Workaround for non-thread-safe use of HyperLogLogCollector.
// OnheapIncrementalIndex has a penchant for calling "aggregate" and "get" simultaneously.
// Must make a new collector duplicating the underlying buffer to ensure the object from "get" is usable
// in a thread-safe manner.
return HyperLogLogCollector.makeCollectorSharingStorage(collector);
}

View File

@ -61,10 +61,10 @@ public class RowBasedColumnSelectorFactory<T> implements ColumnSelectorFactory
private final boolean useStringValueOfNullInLists;
/**
* Package-private constructor for {@link RowBasedCursor}. Allows passing in a rowIdSupplier, which enables
* Full constructor for {@link RowBasedCursor}. Allows passing in a rowIdSupplier, which enables
* column value reuse optimizations.
*/
RowBasedColumnSelectorFactory(
public RowBasedColumnSelectorFactory(
final Supplier<T> rowSupplier,
@Nullable final RowIdSupplier rowIdSupplier,
final RowAdapter<T> adapter,

View File

@ -21,8 +21,8 @@ package org.apache.druid.segment.incremental;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
@ -94,26 +94,40 @@ import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* In-memory, row-based data structure used to hold data during ingestion. Realtime tasks query this index using
* {@link IncrementalIndexStorageAdapter}.
*
* Concurrency model: {@link #add(InputRow)} and {@link #add(InputRow, boolean)} are not thread-safe, and must be
* called from a single thread or externally synchronized. However, the methods that support
* {@link IncrementalIndexStorageAdapter} are thread-safe, and may be called concurrently with each other, and with
* the "add" methods. This concurrency model supports real-time queries of the data in the index.
*/
public abstract class IncrementalIndex implements Iterable<Row>, Closeable, ColumnInspector
{
/**
* Column selector used at ingestion time for inputs to aggregators.
*
* @param agg the aggregator
* @param in ingestion-time input row supplier
* @param virtualColumns virtual columns
* @param inputRowHolder ingestion-time input row holder
* @param agg the aggregator, or null to make a generic aggregator. Only required if the agg has
* {@link AggregatorFactory#getIntermediateType()} as {@link ValueType#COMPLEX}, because
* in this case we need to do some magic to ensure the correct values show up.
*
* @return column selector factory
*/
public static ColumnSelectorFactory makeColumnSelectorFactory(
final VirtualColumns virtualColumns,
final AggregatorFactory agg,
final Supplier<InputRow> in
final InputRowHolder inputRowHolder,
@Nullable final AggregatorFactory agg
)
{
// we use RowSignature.empty() because ColumnInspector here should be the InputRow schema, not the
// IncrementalIndex schema, because we are reading values from the InputRow
final RowBasedColumnSelectorFactory<InputRow> baseSelectorFactory = RowBasedColumnSelectorFactory.create(
final RowBasedColumnSelectorFactory<InputRow> baseSelectorFactory = new RowBasedColumnSelectorFactory<>(
inputRowHolder::getRow,
inputRowHolder::getRowId,
RowAdapters.standardRow(),
in,
RowSignature.empty(),
true,
true
@ -126,7 +140,7 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
{
final ColumnValueSelector selector = baseSelectorFactory.makeColumnValueSelector(column);
if (!agg.getIntermediateType().is(ValueType.COMPLEX)) {
if (agg == null || !agg.getIntermediateType().is(ValueType.COMPLEX)) {
return selector;
} else {
// Wrap selector in a special one that uses ComplexMetricSerde to modify incoming objects.
@ -176,13 +190,13 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
public Object getObject()
{
// Here is where the magic happens: read from "in" directly, don't go through the normal "selector".
return extractor.extractValue(in.get(), column, agg);
return extractor.extractValue(inputRowHolder.getRow(), column, agg);
}
@Override
public void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
inspector.visit("in", in);
inspector.visit("inputRowHolder", inputRowHolder);
inspector.visit("selector", selector);
inspector.visit("extractor", extractor);
}
@ -230,13 +244,10 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
private final boolean useSchemaDiscovery;
// This is modified on add() in a critical section.
private final ThreadLocal<InputRow> in = new ThreadLocal<>();
private final Supplier<InputRow> rowSupplier = in::get;
private final InputRowHolder inputRowHolder = new InputRowHolder();
private volatile DateTime maxIngestedEventTime;
/**
* @param incrementalIndexSchema the schema to use for incremental index
* @param preserveExistingMetrics When set to true, for any row that already has metric
@ -277,7 +288,7 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
this.rollup
);
initAggs(metrics, rowSupplier);
initAggs(metrics, inputRowHolder);
for (AggregatorFactory metric : metrics) {
MetricDesc metricDesc = new MetricDesc(metricDescs.size(), metric);
@ -333,15 +344,13 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
protected abstract void initAggs(
AggregatorFactory[] metrics,
Supplier<InputRow> rowSupplier
InputRowHolder rowSupplier
);
// Note: This method needs to be thread safe.
// Note: This method does not need to be thread safe.
protected abstract AddToFactsResult addToFacts(
InputRow row,
IncrementalIndexRow key,
ThreadLocal<InputRow> rowContainer,
Supplier<InputRow> rowSupplier,
InputRowHolder inputRowHolder,
boolean skipMaxRowsInMemoryCheck
) throws IndexSizeExceededException;
@ -412,6 +421,34 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
}
}
public static class InputRowHolder
{
@Nullable
private InputRow row;
private long rowId = -1;
public void set(final InputRow row)
{
this.row = row;
this.rowId++;
}
public void unset()
{
this.row = null;
}
public InputRow getRow()
{
return Preconditions.checkNotNull(row, "row");
}
public long getRowId()
{
return rowId;
}
}
public boolean isRollup()
{
return rollup;
@ -474,14 +511,14 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
/**
* Adds a new row. The row might correspond with another row that already exists, in which case this will
* update that row instead of inserting a new one.
* <p>
* <p>
* Calls to add() are thread safe.
* <p>
*
* Not thread-safe.
*
* @param row the row of data to add
*
* @return the number of rows in the data set after adding the InputRow. If any parse failure occurs, a {@link ParseException} is returned in {@link IncrementalIndexAddResult}.
*
* @throws IndexSizeExceededException this exception is thrown once it reaches max rows limit and skipMaxRowsInMemoryCheck is set to false.
*/
public IncrementalIndexAddResult add(InputRow row) throws IndexSizeExceededException
{
@ -491,25 +528,24 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
/**
* Adds a new row. The row might correspond with another row that already exists, in which case this will
* update that row instead of inserting a new one.
* <p>
* <p>
* Calls to add() are thread safe.
* <p>
*
* Not thread-safe.
*
* @param row the row of data to add
* @param skipMaxRowsInMemoryCheck whether or not to skip the check of rows exceeding the max rows limit
* @param skipMaxRowsInMemoryCheck whether or not to skip the check of rows exceeding the max rows or bytes limit
*
* @return the number of rows in the data set after adding the InputRow. If any parse failure occurs, a {@link ParseException} is returned in {@link IncrementalIndexAddResult}.
*
* @throws IndexSizeExceededException this exception is thrown once it reaches max rows limit and skipMaxRowsInMemoryCheck is set to false.
*/
public IncrementalIndexAddResult add(InputRow row, boolean skipMaxRowsInMemoryCheck)
throws IndexSizeExceededException
{
IncrementalIndexRowResult incrementalIndexRowResult = toIncrementalIndexRow(row);
inputRowHolder.set(row);
final AddToFactsResult addToFactsResult = addToFacts(
row,
incrementalIndexRowResult.getIncrementalIndexRow(),
in,
rowSupplier,
inputRowHolder,
skipMaxRowsInMemoryCheck
);
updateMaxIngestedTime(row.getTimestamp());
@ -518,6 +554,7 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
incrementalIndexRowResult.getParseExceptionMessages(),
addToFactsResult.getParseExceptionMessages()
);
inputRowHolder.unset();
return new IncrementalIndexAddResult(
addToFactsResult.getRowCount(),
addToFactsResult.getBytesInMemory(),
@ -1022,11 +1059,11 @@ public abstract class IncrementalIndex implements Iterable<Row>, Closeable, Colu
}
protected ColumnSelectorFactory makeColumnSelectorFactory(
final AggregatorFactory agg,
final Supplier<InputRow> in
@Nullable final AggregatorFactory agg,
final InputRowHolder in
)
{
return makeColumnSelectorFactory(virtualColumns, agg, in);
return makeColumnSelectorFactory(virtualColumns, in, agg);
}
protected final Comparator<IncrementalIndexRow> dimsComparator()

View File

@ -21,12 +21,11 @@ package org.apache.druid.segment.incremental;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Supplier;
import com.google.common.collect.Iterators;
import com.google.common.collect.Maps;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.MapBasedRow;
import org.apache.druid.data.input.Row;
import org.apache.druid.error.DruidException;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
@ -42,6 +41,7 @@ import org.apache.druid.segment.DimensionHandler;
import org.apache.druid.segment.DimensionIndexer;
import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ValueType;
import org.apache.druid.utils.JvmUtils;
import javax.annotation.Nullable;
@ -118,10 +118,17 @@ public class OnheapIncrementalIndex extends IncrementalIndex
*/
private final boolean useMaxMemoryEstimates;
/**
* Aggregator name -> column selector factory for that aggregator.
*/
@Nullable
private volatile Map<String, ColumnSelectorFactory> selectors;
private Map<String, ColumnSelectorFactory> selectors;
/**
* Aggregator name -> column selector factory for the combining version of that aggregator. Only set when
* {@link #preserveExistingMetrics} is true.
*/
@Nullable
private volatile Map<String, ColumnSelectorFactory> combiningAggSelectors;
private Map<String, ColumnSelectorFactory> combiningAggSelectors;
@Nullable
private String outOfRowsReason = null;
@ -190,34 +197,49 @@ public class OnheapIncrementalIndex extends IncrementalIndex
@Override
protected void initAggs(
final AggregatorFactory[] metrics,
final Supplier<InputRow> rowSupplier
final InputRowHolder inputRowHolder
)
{
// All non-complex aggregators share a column selector factory. Helps with value reuse.
ColumnSelectorFactory nonComplexColumnSelectorFactory = null;
selectors = new HashMap<>();
combiningAggSelectors = new HashMap<>();
for (AggregatorFactory agg : metrics) {
selectors.put(
agg.getName(),
new CachingColumnSelectorFactory(makeColumnSelectorFactory(agg, rowSupplier))
);
if (preserveExistingMetrics) {
AggregatorFactory combiningAgg = agg.getCombiningFactory();
combiningAggSelectors.put(
combiningAgg.getName(),
new CachingColumnSelectorFactory(
makeColumnSelectorFactory(combiningAgg, rowSupplier)
)
);
final ColumnSelectorFactory factory;
if (agg.getIntermediateType().is(ValueType.COMPLEX)) {
factory = new CachingColumnSelectorFactory(makeColumnSelectorFactory(agg, inputRowHolder));
} else {
if (nonComplexColumnSelectorFactory == null) {
nonComplexColumnSelectorFactory =
new CachingColumnSelectorFactory(makeColumnSelectorFactory(null, inputRowHolder));
}
factory = nonComplexColumnSelectorFactory;
}
selectors.put(agg.getName(), factory);
}
if (preserveExistingMetrics) {
for (AggregatorFactory agg : metrics) {
final AggregatorFactory combiningAgg = agg.getCombiningFactory();
final ColumnSelectorFactory factory;
if (combiningAgg.getIntermediateType().is(ValueType.COMPLEX)) {
factory = new CachingColumnSelectorFactory(makeColumnSelectorFactory(combiningAgg, inputRowHolder));
} else {
if (nonComplexColumnSelectorFactory == null) {
nonComplexColumnSelectorFactory =
new CachingColumnSelectorFactory(makeColumnSelectorFactory(null, inputRowHolder));
}
factory = nonComplexColumnSelectorFactory;
}
combiningAggSelectors.put(combiningAgg.getName(), factory);
}
}
}
@Override
protected AddToFactsResult addToFacts(
InputRow row,
IncrementalIndexRow key,
ThreadLocal<InputRow> rowContainer,
Supplier<InputRow> rowSupplier,
InputRowHolder inputRowHolder,
boolean skipMaxRowsInMemoryCheck
) throws IndexSizeExceededException
{
@ -230,7 +252,7 @@ public class OnheapIncrementalIndex extends IncrementalIndex
final AtomicLong totalSizeInBytes = getBytesInMemory();
if (IncrementalIndexRow.EMPTY_ROW_INDEX != priorIndex) {
aggs = concurrentGet(priorIndex);
long aggSizeDelta = doAggregate(metrics, aggs, rowContainer, row, parseExceptionMessages);
long aggSizeDelta = doAggregate(metrics, aggs, inputRowHolder, parseExceptionMessages);
totalSizeInBytes.addAndGet(useMaxMemoryEstimates ? 0 : aggSizeDelta);
} else {
if (preserveExistingMetrics) {
@ -238,8 +260,8 @@ public class OnheapIncrementalIndex extends IncrementalIndex
} else {
aggs = new Aggregator[metrics.length];
}
long aggSizeForRow = factorizeAggs(metrics, aggs, rowContainer, row);
aggSizeForRow += doAggregate(metrics, aggs, rowContainer, row, parseExceptionMessages);
long aggSizeForRow = factorizeAggs(metrics, aggs);
aggSizeForRow += doAggregate(metrics, aggs, inputRowHolder, parseExceptionMessages);
final int rowIndex = indexIncrement.getAndIncrement();
concurrentSet(rowIndex, aggs);
@ -258,15 +280,7 @@ public class OnheapIncrementalIndex extends IncrementalIndex
if (IncrementalIndexRow.EMPTY_ROW_INDEX == prev) {
numEntries.incrementAndGet();
} else {
// this should never happen. Previously, this would happen in a race condition involving multiple write threads
// for GroupBy v1 strategy, but it is no more, so this code needs the concurrency model reworked in the future
parseExceptionMessages.clear();
aggs = concurrentGet(prev);
aggSizeForRow = doAggregate(metrics, aggs, rowContainer, row, parseExceptionMessages);
// Free up the misfire
concurrentRemove(rowIndex);
// This is expected to occur ~80% of the time in the worst scenarios
throw DruidException.defensive("Encountered existing fact entry for new key, possible concurrent add?");
}
// For a new key, row size = key size + aggregator size + overhead
@ -295,13 +309,10 @@ public class OnheapIncrementalIndex extends IncrementalIndex
*/
private long factorizeAggs(
AggregatorFactory[] metrics,
Aggregator[] aggs,
ThreadLocal<InputRow> rowContainer,
InputRow row
Aggregator[] aggs
)
{
long totalInitialSizeBytes = 0L;
rowContainer.set(row);
final long aggReferenceSize = Long.BYTES;
for (int i = 0; i < metrics.length; i++) {
final AggregatorFactory agg = metrics[i];
@ -328,7 +339,6 @@ public class OnheapIncrementalIndex extends IncrementalIndex
}
}
}
rowContainer.set(null);
return totalInitialSizeBytes;
}
@ -342,42 +352,44 @@ public class OnheapIncrementalIndex extends IncrementalIndex
private long doAggregate(
AggregatorFactory[] metrics,
Aggregator[] aggs,
ThreadLocal<InputRow> rowContainer,
InputRow row,
InputRowHolder inputRowHolder,
List<String> parseExceptionsHolder
)
{
rowContainer.set(row);
long totalIncrementalBytes = 0L;
for (int i = 0; i < metrics.length; i++) {
final Aggregator agg;
if (preserveExistingMetrics && row instanceof MapBasedRow && ((MapBasedRow) row).getEvent().containsKey(metrics[i].getName())) {
if (preserveExistingMetrics
&& inputRowHolder.getRow() instanceof MapBasedRow
&& ((MapBasedRow) inputRowHolder.getRow()).getEvent().containsKey(metrics[i].getName())) {
agg = aggs[i + metrics.length];
} else {
agg = aggs[i];
}
synchronized (agg) {
try {
if (useMaxMemoryEstimates) {
agg.aggregate();
} else {
totalIncrementalBytes += agg.aggregateWithSize();
}
try {
if (useMaxMemoryEstimates) {
agg.aggregate();
} else {
totalIncrementalBytes += agg.aggregateWithSize();
}
catch (ParseException e) {
// "aggregate" can throw ParseExceptions if a selector expects something but gets something else.
if (preserveExistingMetrics) {
log.warn(e, "Failing ingestion as preserveExistingMetrics is enabled but selector of aggregator[%s] recieved incompatible type.", metrics[i].getName());
throw e;
} else {
log.debug(e, "Encountered parse error, skipping aggregator[%s].", metrics[i].getName());
parseExceptionsHolder.add(e.getMessage());
}
}
catch (ParseException e) {
// "aggregate" can throw ParseExceptions if a selector expects something but gets something else.
if (preserveExistingMetrics) {
log.warn(
e,
"Failing ingestion as preserveExistingMetrics is enabled but selector of aggregator[%s] received "
+ "incompatible type.",
metrics[i].getName()
);
throw e;
} else {
log.debug(e, "Encountered parse error, skipping aggregator[%s].", metrics[i].getName());
parseExceptionsHolder.add(e.getMessage());
}
}
}
rowContainer.set(null);
return totalIncrementalBytes;
}
@ -409,11 +421,6 @@ public class OnheapIncrementalIndex extends IncrementalIndex
aggregators.put(offset, value);
}
protected void concurrentRemove(int offset)
{
aggregators.remove(offset);
}
@Override
public boolean canAppendRow()
{

View File

@ -36,7 +36,6 @@ import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.guice.NestedDataModule;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Accumulator;
import org.apache.druid.java.util.common.guava.Sequence;
@ -87,9 +86,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
/**
@ -463,11 +460,11 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
final IncrementalIndex index = indexCreator.createIndex(
(Object) ingestAggregatorFactories.toArray(new AggregatorFactory[0])
);
final int concurrentThreads = 2;
final int addThreads = 1;
final int elementsPerThread = 10_000;
final ListeningExecutorService indexExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
addThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("index-executor-%d")
@ -477,7 +474,7 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
);
final ListeningExecutorService queryExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
addThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("query-executor-%d")
@ -486,8 +483,8 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
);
final long timestamp = System.currentTimeMillis();
final Interval queryInterval = Intervals.of("1900-01-01T00:00:00Z/2900-01-01T00:00:00Z");
final List<ListenableFuture<?>> indexFutures = Lists.newArrayListWithExpectedSize(concurrentThreads);
final List<ListenableFuture<?>> queryFutures = Lists.newArrayListWithExpectedSize(concurrentThreads);
final List<ListenableFuture<?>> indexFutures = Lists.newArrayListWithExpectedSize(addThreads);
final List<ListenableFuture<?>> queryFutures = Lists.newArrayListWithExpectedSize(addThreads);
final Segment incrementalIndexSegment = new IncrementalIndexSegment(index, null);
final QueryRunnerFactory factory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(),
@ -498,9 +495,9 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
final AtomicInteger concurrentlyRan = new AtomicInteger(0);
final AtomicInteger someoneRan = new AtomicInteger(0);
final CountDownLatch startLatch = new CountDownLatch(1);
final CountDownLatch readyLatch = new CountDownLatch(concurrentThreads * 2);
final CountDownLatch readyLatch = new CountDownLatch(addThreads * 2);
final AtomicInteger queriesAccumualted = new AtomicInteger(0);
for (int j = 0; j < concurrentThreads; j++) {
for (int j = 0; j < addThreads; j++) {
indexFutures.add(
indexExecutor.submit(
new Runnable()
@ -577,7 +574,7 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
}
);
for (Double result : results) {
final Integer maxValueExpected = someoneRan.get() + concurrentThreads;
final int maxValueExpected = someoneRan.get() + addThreads;
if (maxValueExpected > 0) {
// Eventually consistent, but should be somewhere in that range
// Actual result is validated after all writes are guaranteed done.
@ -617,70 +614,24 @@ public class IncrementalIndexTest extends InitializedNullHandlingTest
boolean isRollup = index.isRollup();
for (Result<TimeseriesResultValue> result : results) {
Assert.assertEquals(
elementsPerThread * (isRollup ? 1 : concurrentThreads),
elementsPerThread * (isRollup ? 1 : addThreads),
result.getValue().getLongMetric("rows").intValue()
);
for (int i = 0; i < dimensionCount; ++i) {
Assert.assertEquals(
StringUtils.format("Failed long sum on dimension %d", i),
elementsPerThread * concurrentThreads,
elementsPerThread * addThreads,
result.getValue().getLongMetric(StringUtils.format("sumResult%s", i)).intValue()
);
Assert.assertEquals(
StringUtils.format("Failed double sum on dimension %d", i),
elementsPerThread * concurrentThreads,
elementsPerThread * addThreads,
result.getValue().getDoubleMetric(StringUtils.format("doubleSumResult%s", i)).intValue()
);
}
}
}
@Test
public void testConcurrentAdd() throws Exception
{
final IncrementalIndex index = indexCreator.createIndex((Object) DEFAULT_AGGREGATOR_FACTORIES);
final int threadCount = 10;
final int elementsPerThread = 200;
final int dimensionCount = 5;
ExecutorService executor = Execs.multiThreaded(threadCount, "IncrementalIndexTest-%d");
final long timestamp = System.currentTimeMillis();
final CountDownLatch latch = new CountDownLatch(threadCount);
for (int j = 0; j < threadCount; j++) {
executor.submit(
new Runnable()
{
@Override
public void run()
{
try {
for (int i = 0; i < elementsPerThread; i++) {
index.add(getRow(timestamp + i, i, dimensionCount));
}
}
catch (Exception e) {
e.printStackTrace();
}
latch.countDown();
}
}
);
}
Assert.assertTrue(latch.await(60, TimeUnit.SECONDS));
boolean isRollup = index.isRollup();
Assert.assertEquals(dimensionCount, index.getDimensionNames().size());
Assert.assertEquals(elementsPerThread * (isRollup ? 1 : threadCount), index.size());
Iterator<Row> iterator = index.iterator();
int curr = 0;
while (iterator.hasNext()) {
Row row = iterator.next();
Assert.assertEquals(timestamp + (isRollup ? curr : curr / threadCount), row.getTimestampFromEpoch());
Assert.assertEquals(isRollup ? threadCount : 1, row.getMetric("count").intValue());
curr++;
}
Assert.assertEquals(elementsPerThread * (isRollup ? 1 : threadCount), curr);
}
@Test
public void testgetDimensions()
{

View File

@ -20,22 +20,16 @@
package org.apache.druid.segment.incremental;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.guice.NestedDataModule;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.js.JavaScriptConfig;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.JavaScriptAggregatorFactory;
import org.apache.druid.query.aggregation.LongMaxAggregator;
import org.apache.druid.query.aggregation.LongMaxAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.segment.CloserRule;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.easymock.EasyMock;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -43,9 +37,6 @@ import org.junit.runners.Parameterized;
import java.util.Collection;
import java.util.Collections;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
@RunWith(Parameterized.class)
public class IncrementalIndexIngestionTest extends InitializedNullHandlingTest
@ -73,149 +64,6 @@ public class IncrementalIndexIngestionTest extends InitializedNullHandlingTest
return IncrementalIndexCreator.getAppendableIndexTypes();
}
@Test
public void testMultithreadAddFacts() throws Exception
{
final IncrementalIndex index = indexCreator.createIndex(new IncrementalIndexSchema.Builder()
.withQueryGranularity(Granularities.MINUTE)
.withMetrics(new LongMaxAggregatorFactory("max", "max"))
.build()
);
final int addThreadCount = 2;
Thread[] addThreads = new Thread[addThreadCount];
for (int i = 0; i < addThreadCount; ++i) {
addThreads[i] = new Thread(new Runnable()
{
@Override
public void run()
{
final Random random = ThreadLocalRandom.current();
try {
for (int j = 0; j < MAX_ROWS / addThreadCount; ++j) {
index.add(new MapBasedInputRow(
0,
Collections.singletonList("billy"),
ImmutableMap.of("billy", random.nextLong(), "max", 1)
));
}
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
});
addThreads[i].start();
}
final AtomicInteger checkFailedCount = new AtomicInteger(0);
Thread checkThread = new Thread(new Runnable()
{
@Override
public void run()
{
while (!Thread.interrupted()) {
for (IncrementalIndexRow row : index.getFacts().keySet()) {
if (index.getMetricLongValue(row.getRowIndex(), 0) != 1) {
checkFailedCount.addAndGet(1);
}
}
}
}
});
checkThread.start();
for (int i = 0; i < addThreadCount; ++i) {
addThreads[i].join();
}
checkThread.interrupt();
Assert.assertEquals(0, checkFailedCount.get());
}
@Test
public void testMultithreadAddFactsUsingExpressionAndJavaScript() throws Exception
{
final IncrementalIndex indexExpr = indexCreator.createIndex(
new IncrementalIndexSchema.Builder()
.withQueryGranularity(Granularities.MINUTE)
.withMetrics(new LongSumAggregatorFactory(
"oddnum",
null,
"if(value%2==1,1,0)",
TestExprMacroTable.INSTANCE
))
.withRollup(true)
.build()
);
final IncrementalIndex indexJs = indexCreator.createIndex(
new IncrementalIndexSchema.Builder()
.withQueryGranularity(Granularities.MINUTE)
.withMetrics(new JavaScriptAggregatorFactory(
"oddnum",
ImmutableList.of("value"),
"function(current, value) { if (value%2==1) current = current + 1; return current;}",
"function() {return 0;}",
"function(a, b) { return a + b;}",
JavaScriptConfig.getEnabledInstance()
))
.withRollup(true)
.build()
);
final int addThreadCount = 2;
Thread[] addThreads = new Thread[addThreadCount];
for (int i = 0; i < addThreadCount; ++i) {
addThreads[i] = new Thread(new Runnable()
{
@Override
public void run()
{
final Random random = ThreadLocalRandom.current();
try {
for (int j = 0; j < MAX_ROWS / addThreadCount; ++j) {
int randomInt = random.nextInt(100000);
MapBasedInputRow mapBasedInputRowExpr = new MapBasedInputRow(
0,
Collections.singletonList("billy"),
ImmutableMap.of("billy", randomInt % 3, "value", randomInt)
);
MapBasedInputRow mapBasedInputRowJs = new MapBasedInputRow(
0,
Collections.singletonList("billy"),
ImmutableMap.of("billy", randomInt % 3, "value", randomInt)
);
indexExpr.add(mapBasedInputRowExpr);
indexJs.add(mapBasedInputRowJs);
}
}
catch (Exception e) {
throw new RuntimeException(e);
}
}
});
addThreads[i].start();
}
for (int i = 0; i < addThreadCount; ++i) {
addThreads[i].join();
}
long exprSum = 0;
long jsSum = 0;
for (IncrementalIndexRow row : indexExpr.getFacts().keySet()) {
exprSum += indexExpr.getMetricLongValue(row.getRowIndex(), 0);
}
for (IncrementalIndexRow row : indexJs.getFacts().keySet()) {
jsSum += indexJs.getMetricLongValue(row.getRowIndex(), 0);
}
Assert.assertEquals(exprSum, jsSum);
}
@Test
public void testOnHeapIncrementalIndexClose() throws Exception
{

View File

@ -1,428 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.segment.incremental;
import com.carrotsearch.junitbenchmarks.AbstractBenchmark;
import com.carrotsearch.junitbenchmarks.BenchmarkOptions;
import com.carrotsearch.junitbenchmarks.Clock;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.MapBasedInputRow;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.Druids;
import org.apache.druid.query.FinalizeResultsQueryRunner;
import org.apache.druid.query.QueryPlus;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QueryRunnerFactory;
import org.apache.druid.query.QueryRunnerTestHelper;
import org.apache.druid.query.Result;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.query.timeseries.TimeseriesQueryEngine;
import org.apache.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import org.apache.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import org.apache.druid.query.timeseries.TimeseriesResultValue;
import org.apache.druid.segment.IncrementalIndexSegment;
import org.apache.druid.segment.Segment;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* Extending AbstractBenchmark means only runs if explicitly called
*/
@RunWith(Parameterized.class)
public class OnheapIncrementalIndexBenchmark extends AbstractBenchmark
{
private static AggregatorFactory[] factories;
static final int DIMENSION_COUNT = 5;
static {
final ArrayList<AggregatorFactory> ingestAggregatorFactories = new ArrayList<>(DIMENSION_COUNT + 1);
ingestAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < DIMENSION_COUNT; ++i) {
ingestAggregatorFactories.add(
new LongSumAggregatorFactory(
StringUtils.format("sumResult%s", i),
StringUtils.format("Dim_%s", i)
)
);
ingestAggregatorFactories.add(
new DoubleSumAggregatorFactory(
StringUtils.format("doubleSumResult%s", i),
StringUtils.format("Dim_%s", i)
)
);
}
factories = ingestAggregatorFactories.toArray(new AggregatorFactory[0]);
}
private static final class MapIncrementalIndex extends OnheapIncrementalIndex
{
private final AtomicInteger indexIncrement = new AtomicInteger(0);
ConcurrentHashMap<Integer, Aggregator[]> indexedMap = new ConcurrentHashMap<Integer, Aggregator[]>();
public MapIncrementalIndex(
IncrementalIndexSchema incrementalIndexSchema,
int maxRowCount,
long maxBytesInMemory
)
{
super(
incrementalIndexSchema,
maxRowCount,
maxBytesInMemory,
false,
true
);
}
public MapIncrementalIndex(
long minTimestamp,
Granularity gran,
AggregatorFactory[] metrics,
int maxRowCount,
long maxBytesInMemory
)
{
super(
new IncrementalIndexSchema.Builder()
.withMinTimestamp(minTimestamp)
.withQueryGranularity(gran)
.withMetrics(metrics)
.build(),
maxRowCount,
maxBytesInMemory,
false,
true
);
}
@Override
protected Aggregator[] concurrentGet(int offset)
{
// All get operations should be fine
return indexedMap.get(offset);
}
@Override
protected void concurrentSet(int offset, Aggregator[] value)
{
indexedMap.put(offset, value);
}
@Override
protected AddToFactsResult addToFacts(
InputRow row,
IncrementalIndexRow key,
ThreadLocal<InputRow> rowContainer,
Supplier<InputRow> rowSupplier,
boolean skipMaxRowsInMemoryCheck // ignore for benchmark
) throws IndexSizeExceededException
{
final Integer priorIdex = getFacts().getPriorIndex(key);
Aggregator[] aggs;
final AggregatorFactory[] metrics = getMetrics();
final AtomicInteger numEntries = getNumEntries();
final AtomicLong sizeInBytes = getBytesInMemory();
if (null != priorIdex) {
aggs = indexedMap.get(priorIdex);
} else {
aggs = new Aggregator[metrics.length];
for (int i = 0; i < metrics.length; i++) {
final AggregatorFactory agg = metrics[i];
aggs[i] = agg.factorize(
makeColumnSelectorFactory(agg, rowSupplier)
);
}
Integer rowIndex;
do {
rowIndex = indexIncrement.incrementAndGet();
} while (null != indexedMap.putIfAbsent(rowIndex, aggs));
// Last ditch sanity checks
if ((numEntries.get() >= maxRowCount || sizeInBytes.get() >= maxBytesInMemory)
&& getFacts().getPriorIndex(key) == IncrementalIndexRow.EMPTY_ROW_INDEX) {
throw new IndexSizeExceededException("Maximum number of rows or max bytes reached");
}
final int prev = getFacts().putIfAbsent(key, rowIndex);
if (IncrementalIndexRow.EMPTY_ROW_INDEX == prev) {
numEntries.incrementAndGet();
sizeInBytes.incrementAndGet();
} else {
// We lost a race
aggs = indexedMap.get(prev);
// Free up the misfire
indexedMap.remove(rowIndex);
// This is expected to occur ~80% of the time in the worst scenarios
}
}
rowContainer.set(row);
for (Aggregator agg : aggs) {
synchronized (agg) {
agg.aggregate();
}
}
rowContainer.set(null);
return new AddToFactsResult(numEntries.get(), sizeInBytes.get(), new ArrayList<>());
}
@Override
public int getLastRowIndex()
{
return indexIncrement.get() - 1;
}
}
@Parameterized.Parameters
public static Collection<Object[]> getParameters()
{
return ImmutableList.of(
new Object[]{OnheapIncrementalIndex.class},
new Object[]{MapIncrementalIndex.class}
);
}
private final Class<? extends OnheapIncrementalIndex> incrementalIndex;
public OnheapIncrementalIndexBenchmark(Class<? extends OnheapIncrementalIndex> incrementalIndex)
{
this.incrementalIndex = incrementalIndex;
}
private static MapBasedInputRow getLongRow(long timestamp, int rowID, int dimensionCount)
{
List<String> dimensionList = new ArrayList<String>(dimensionCount);
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
for (int i = 0; i < dimensionCount; i++) {
String dimName = StringUtils.format("Dim_%d", i);
dimensionList.add(dimName);
builder.put(dimName, new Integer(rowID).longValue());
}
return new MapBasedInputRow(timestamp, dimensionList, builder.build());
}
@Ignore
@Test
@BenchmarkOptions(callgc = true, clock = Clock.REAL_TIME, warmupRounds = 10, benchmarkRounds = 20)
public void testConcurrentAddRead()
throws InterruptedException, ExecutionException, NoSuchMethodException, IllegalAccessException,
InvocationTargetException, InstantiationException
{
final int taskCount = 30;
final int concurrentThreads = 3;
final int elementsPerThread = 1 << 15;
final IncrementalIndex incrementalIndex = this.incrementalIndex.getConstructor(
IncrementalIndexSchema.class,
boolean.class,
boolean.class,
boolean.class,
boolean.class,
int.class
).newInstance(
new IncrementalIndexSchema.Builder().withMetrics(factories).build(),
true,
true,
false,
true,
elementsPerThread * taskCount
);
final ArrayList<AggregatorFactory> queryAggregatorFactories = new ArrayList<>(DIMENSION_COUNT + 1);
queryAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < DIMENSION_COUNT; ++i) {
queryAggregatorFactories.add(
new LongSumAggregatorFactory(
StringUtils.format("sumResult%s", i),
StringUtils.format("sumResult%s", i)
)
);
queryAggregatorFactories.add(
new DoubleSumAggregatorFactory(
StringUtils.format("doubleSumResult%s", i),
StringUtils.format("doubleSumResult%s", i)
)
);
}
final ListeningExecutorService indexExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("index-executor-%d")
.setPriority(Thread.MIN_PRIORITY)
.build()
)
);
final ListeningExecutorService queryExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("query-executor-%d")
.build()
)
);
final long timestamp = System.currentTimeMillis();
final Interval queryInterval = Intervals.of("1900-01-01T00:00:00Z/2900-01-01T00:00:00Z");
final List<ListenableFuture<?>> indexFutures = new ArrayList<>();
final List<ListenableFuture<?>> queryFutures = new ArrayList<>();
final Segment incrementalIndexSegment = new IncrementalIndexSegment(incrementalIndex, null);
final QueryRunnerFactory factory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
);
final AtomicInteger currentlyRunning = new AtomicInteger(0);
final AtomicBoolean concurrentlyRan = new AtomicBoolean(false);
final AtomicBoolean someoneRan = new AtomicBoolean(false);
for (int j = 0; j < taskCount; j++) {
indexFutures.add(
indexExecutor.submit(
new Runnable()
{
@Override
public void run()
{
currentlyRunning.incrementAndGet();
try {
for (int i = 0; i < elementsPerThread; i++) {
incrementalIndex.add(getLongRow(timestamp + i, 1, DIMENSION_COUNT));
}
}
catch (IndexSizeExceededException e) {
throw new RuntimeException(e);
}
currentlyRunning.decrementAndGet();
someoneRan.set(true);
}
}
)
);
queryFutures.add(
queryExecutor.submit(
new Runnable()
{
@Override
public void run()
{
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(Granularities.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
List<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
for (Result<TimeseriesResultValue> result : results) {
if (someoneRan.get()) {
Assert.assertTrue(result.getValue().getDoubleMetric("doubleSumResult0") > 0);
}
}
if (currentlyRunning.get() > 0) {
concurrentlyRan.set(true);
}
}
}
)
);
}
List<ListenableFuture<?>> allFutures = new ArrayList<>(queryFutures.size() + indexFutures.size());
allFutures.addAll(queryFutures);
allFutures.addAll(indexFutures);
Futures.allAsList(allFutures).get();
//Assert.assertTrue("Did not hit concurrency, please try again", concurrentlyRan.get());
queryExecutor.shutdown();
indexExecutor.shutdown();
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(Granularities.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
List<Result<TimeseriesResultValue>> results = runner.run(QueryPlus.wrap(query)).toList();
final int expectedVal = elementsPerThread * taskCount;
for (Result<TimeseriesResultValue> result : results) {
Assert.assertEquals(elementsPerThread, result.getValue().getLongMetric("rows").intValue());
for (int i = 0; i < DIMENSION_COUNT; ++i) {
Assert.assertEquals(
StringUtils.format("Failed long sum on dimension %d", i),
expectedVal,
result.getValue().getLongMetric(StringUtils.format("sumResult%s", i)).intValue()
);
Assert.assertEquals(
StringUtils.format("Failed double sum on dimension %d", i),
expectedVal,
result.getValue().getDoubleMetric(StringUtils.format("doubleSumResult%s", i)).intValue()
);
}
}
}
}