Merge pull request #1027 from metamx/concurrentOnHeapIncrementalIndexFix

Fix concurrency issues in OnheapIncrementalIndex
This commit is contained in:
Xavier Léauté 2015-01-16 12:54:42 -08:00
commit 3b3aad78cb
5 changed files with 999 additions and 45 deletions

View File

@ -21,7 +21,6 @@ package io.druid.segment.incremental;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.metamx.common.ISE;
import io.druid.data.input.InputRow;
@ -30,8 +29,8 @@ import io.druid.query.aggregation.Aggregator;
import io.druid.query.aggregation.AggregatorFactory;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
@ -41,16 +40,20 @@ import java.util.concurrent.atomic.AtomicInteger;
*/
public class OnheapIncrementalIndex extends IncrementalIndex<Aggregator>
{
private final ConcurrentNavigableMap<TimeAndDims, Integer> facts;
private final List<Aggregator[]> aggList = Lists.newArrayList();
private final int maxRowCount;
private final ConcurrentHashMap<Integer, Aggregator[]> aggregators = new ConcurrentHashMap<>();
private final ConcurrentNavigableMap<TimeAndDims, Integer> facts = new ConcurrentSkipListMap<>();
private final AtomicInteger indexIncrement = new AtomicInteger(0);
protected final int maxRowCount;
private String outOfRowsReason = null;
public OnheapIncrementalIndex(IncrementalIndexSchema incrementalIndexSchema, boolean deserializeComplexMetrics, int maxRowCount)
public OnheapIncrementalIndex(
IncrementalIndexSchema incrementalIndexSchema,
boolean deserializeComplexMetrics,
int maxRowCount
)
{
super(incrementalIndexSchema, deserializeComplexMetrics);
this.facts = new ConcurrentSkipListMap<>();
this.maxRowCount = maxRowCount;
}
@ -127,40 +130,70 @@ public class OnheapIncrementalIndex extends IncrementalIndex<Aggregator>
ThreadLocal<InputRow> in
) throws IndexSizeExceededException
{
Integer rowOffset;
synchronized (this) {
rowOffset = numEntries.get();
if(rowOffset >= maxRowCount && !facts.containsKey(key)) {
throw new IndexSizeExceededException("Maximum number of rows reached");
}
final Integer prev = facts.putIfAbsent(key, rowOffset);
if (prev != null) {
rowOffset = prev;
final Integer priorIndex = facts.get(key);
Aggregator[] aggs;
if (null != priorIndex) {
aggs = concurrentGet(priorIndex);
} else {
Aggregator[] aggs = new Aggregator[metrics.length];
aggs = new Aggregator[metrics.length];
for (int i = 0; i < metrics.length; i++) {
final AggregatorFactory agg = metrics[i];
aggs[i] = agg.factorize(
makeColumnSelectorFactory(agg, in, deserializeComplexMetrics)
);
}
aggList.add(aggs);
final Integer rowIndex = indexIncrement.getAndIncrement();
concurrentSet(rowIndex, aggs);
// Last ditch sanity checks
if (numEntries.get() >= maxRowCount && !facts.containsKey(key)) {
throw new IndexSizeExceededException("Maximum number of rows reached");
}
final Integer prev = facts.putIfAbsent(key, rowIndex);
if (null == prev) {
numEntries.incrementAndGet();
} else {
// We lost a race
aggs = concurrentGet(prev);
// Free up the misfire
concurrentRemove(rowIndex);
// This is expected to occur ~80% of the time in the worst scenarios
}
}
in.set(row);
final Aggregator[] aggs = aggList.get(rowOffset);
for (int i = 0; i < aggs.length; i++) {
synchronized (aggs[i]) {
aggs[i].aggregate();
for (Aggregator agg : aggs) {
synchronized (agg) {
agg.aggregate();
}
}
in.set(null);
return numEntries.get();
}
protected Aggregator[] concurrentGet(int offset)
{
// All get operations should be fine
return aggregators.get(offset);
}
protected void concurrentSet(int offset, Aggregator[] value)
{
aggregators.put(offset, value);
}
protected void concurrentRemove(int offset)
{
aggregators.remove(offset);
}
@Override
public boolean canAppendRow()
{
@ -180,7 +213,7 @@ public class OnheapIncrementalIndex extends IncrementalIndex<Aggregator>
@Override
protected Aggregator[] getAggsForRow(int rowOffset)
{
return aggList.get(rowOffset);
return concurrentGet(rowOffset);
}
@Override
@ -192,19 +225,19 @@ public class OnheapIncrementalIndex extends IncrementalIndex<Aggregator>
@Override
public float getMetricFloatValue(int rowOffset, int aggOffset)
{
return aggList.get(rowOffset)[aggOffset].getFloat();
return concurrentGet(rowOffset)[aggOffset].getFloat();
}
@Override
public long getMetricLongValue(int rowOffset, int aggOffset)
{
return aggList.get(rowOffset)[aggOffset].getLong();
return concurrentGet(rowOffset)[aggOffset].getLong();
}
@Override
public Object getMetricObjectValue(int rowOffset, int aggOffset)
{
return aggList.get(rowOffset)[aggOffset].get();
return concurrentGet(rowOffset)[aggOffset].get();
}
private static class OnHeapDimDim implements DimDim

View File

@ -51,7 +51,7 @@ public class IndexMergerTest
{
final long timestamp = System.currentTimeMillis();
IncrementalIndex toPersist = IncrementalIndexTest.createIndex(true);
IncrementalIndex toPersist = IncrementalIndexTest.createIndex(true, null);
IncrementalIndexTest.populateIndex(timestamp, toPersist);
final File tempDir = Files.createTempDir();
@ -71,7 +71,7 @@ public class IndexMergerTest
public void testPersistMerge() throws Exception
{
final long timestamp = System.currentTimeMillis();
IncrementalIndex toPersist1 = IncrementalIndexTest.createIndex(true);
IncrementalIndex toPersist1 = IncrementalIndexTest.createIndex(true, null);
IncrementalIndexTest.populateIndex(timestamp, toPersist1);
IncrementalIndex toPersist2 = new OnheapIncrementalIndex(0L, QueryGranularity.NONE, new AggregatorFactory[]{new CountAggregatorFactory("count")}, 1000);

View File

@ -0,0 +1,304 @@
/*
* Druid - a distributed column store.
* Copyright (C) 2012, 2013, 2014 Metamarkets Group Inc.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package io.druid.segment.data;
import com.carrotsearch.junitbenchmarks.AbstractBenchmark;
import com.carrotsearch.junitbenchmarks.BenchmarkOptions;
import com.carrotsearch.junitbenchmarks.Clock;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
// AbstractBenchmark makes this ignored unless explicitly run
@RunWith(Parameterized.class)
public class BenchmarkIndexibleWrites extends AbstractBenchmark
{
@Parameterized.Parameters
public static Collection<Object[]> getParameters()
{
return ImmutableList.<Object[]>of(
new Object[]{new ConcurrentStandardMap<Integer>()},
new Object[]{new ConcurrentExpandable<Integer>()}
);
}
public BenchmarkIndexibleWrites(ConcurrentIndexible<Integer> concurrentIndexible)
{
this.concurrentIndexible = concurrentIndexible;
}
private static interface ConcurrentIndexible<V>
{
public void set(Integer index, V object);
public V get(Integer index);
public void clear();
}
private static class ConcurrentStandardMap<V> implements ConcurrentIndexible<V>
{
private final ConcurrentHashMap<Integer, V> delegate = new ConcurrentHashMap<>();
@Override
public void set(Integer index, V object)
{
delegate.put(index, object);
}
@Override
public V get(Integer index)
{
return delegate.get(index);
}
@Override
public void clear(){
delegate.clear();
}
}
private static class ConcurrentExpandable<V> implements ConcurrentIndexible<V>
{
private static Integer INIT_SIZE = 1 << 10;
private final AtomicReference<V[]> reference = new AtomicReference<>();
private final AtomicLong resizeCount = new AtomicLong(0);
private final Integer initSize;
public ConcurrentExpandable()
{
this(INIT_SIZE);
}
@SuppressWarnings("unchecked")
public ConcurrentExpandable(Integer initSize)
{
reference.set((V[]) new Object[initSize]);
this.initSize = initSize;
}
@Override
public V get(Integer index)
{
return reference.get()[index];
}
@SuppressWarnings("unchecked")
@Override
public void clear()
{
reference.set((V[]) new Object[initSize]);
}
private static Boolean wasCopying(Long val)
{
return (val & 1l) > 0;
}
@Override
public void set(Integer index, V object)
{
ensureCapacity(index + 1);
Long pre, post;
do {
pre = resizeCount.get();
reference.get()[index] = object;
post = resizeCount.get();
} while (wasCopying(pre) || wasCopying(post) || (!pre.equals(post)));
}
private final Object resizeMutex = new Object();
private void ensureCapacity(int capacity)
{
synchronized (resizeMutex) {
if (reference.get().length < capacity) {
// We increment twice per resize. Once before the copy starts and once after the swap.
//
// Any task who sees a resizeCount which is *odd* between the start and stop of their critical section
// has access to a nebulous aggList and should try again
//
// Any task who sees a resizeCount which changes between the start and stop of their critical section
// should also try again
resizeCount.incrementAndGet();
reference.set(Arrays.copyOf(reference.get(), reference.get().length<<1));
resizeCount.incrementAndGet();
}
}
}
}
private final ConcurrentIndexible<Integer> concurrentIndexible;
private final Integer concurrentThreads = 1<<2;
private final Integer totalIndexSize = 1<<20;
@BenchmarkOptions(warmupRounds = 100, benchmarkRounds = 100, clock = Clock.REAL_TIME, callgc = true)
@Test
/**
* CALLEN - 2015-01-15 - OSX - Java 1.7.0_71-b14
BenchmarkIndexibleWrites.testConcurrentWrites[0]: [measured 100 out of 200 rounds, threads: 1 (sequential)]
round: 0.24 [+- 0.01], round.block: 0.00 [+- 0.00], round.gc: 0.02 [+- 0.00], GC.calls: 396, GC.time: 1.88, time.total: 50.60, time.warmup: 24.84, time.bench: 25.77
BenchmarkIndexibleWrites.testConcurrentWrites[1]: [measured 100 out of 200 rounds, threads: 1 (sequential)]
round: 0.15 [+- 0.01], round.block: 0.00 [+- 0.00], round.gc: 0.02 [+- 0.00], GC.calls: 396, GC.time: 2.11, time.total: 33.14, time.warmup: 16.09, time.bench: 17.05
*/
public void testConcurrentWrites() throws ExecutionException, InterruptedException
{
final ListeningExecutorService executorService = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("indexible-writes-benchmark-%d")
.build()
)
);
final AtomicInteger index = new AtomicInteger(0);
List<ListenableFuture<?>> futures = new LinkedList<>();
final Integer loops = totalIndexSize / concurrentThreads;
for (int i = 0; i < concurrentThreads; ++i) {
futures.add(
executorService.submit(
new Runnable()
{
@Override
public void run()
{
for (int i = 0; i < loops; ++i) {
final Integer idx = index.getAndIncrement();
concurrentIndexible.set(idx, idx);
}
}
}
)
);
}
Futures.allAsList(futures).get();
Assert.assertTrue(String.format("Index too small %d, expected %d across %d loops", index.get(), totalIndexSize, loops), index.get()>=totalIndexSize);
for(int i = 0; i < index.get(); ++i){
Assert.assertEquals(i, concurrentIndexible.get(i).intValue());
}
concurrentIndexible.clear();
futures.clear();
executorService.shutdown();
}
/**
BenchmarkIndexibleWrites.TestConcurrentReads[0]: [measured 100 out of 200 rounds, threads: 1 (sequential)]
round: 0.28 [+- 0.02], round.block: 0.00 [+- 0.00], round.gc: 0.02 [+- 0.00], GC.calls: 396, GC.time: 1.84, time.total: 59.98, time.warmup: 30.51, time.bench: 29.48
BenchmarkIndexibleWrites.TestConcurrentReads[1]: [measured 100 out of 200 rounds, threads: 1 (sequential)]
round: 0.12 [+- 0.01], round.block: 0.00 [+- 0.00], round.gc: 0.02 [+- 0.00], GC.calls: 396, GC.time: 2.05, time.total: 29.21, time.warmup: 14.65, time.bench: 14.55
*/
@BenchmarkOptions(warmupRounds = 100, benchmarkRounds = 100, clock = Clock.REAL_TIME, callgc = true)
@Test
public void testConcurrentReads() throws ExecutionException, InterruptedException
{
final ListeningExecutorService executorService = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("indexible-writes-benchmark-reader-%d")
.build()
)
);
final AtomicInteger index = new AtomicInteger(0);
final AtomicInteger queryableIndex = new AtomicInteger(0);
List<ListenableFuture<?>> futures = new LinkedList<>();
final Integer loops = totalIndexSize / concurrentThreads;
final AtomicBoolean done = new AtomicBoolean(false);
final CountDownLatch start = new CountDownLatch(1);
for (int i = 0; i < concurrentThreads; ++i) {
futures.add(
executorService.submit(
new Runnable()
{
@Override
public void run()
{
try {
start.await();
}
catch (InterruptedException e) {
throw Throwables.propagate(e);
}
final Random rndGen = new Random();
while(!done.get()){
Integer idx = rndGen.nextInt(queryableIndex.get() + 1);
Assert.assertEquals(idx, concurrentIndexible.get(idx));
}
}
}
)
);
}
{
final Integer idx = index.getAndIncrement();
concurrentIndexible.set(idx, idx);
start.countDown();
}
for (int i = 1; i < totalIndexSize; ++i) {
final Integer idx = index.getAndIncrement();
concurrentIndexible.set(idx, idx);
queryableIndex.incrementAndGet();
}
done.set(true);
Futures.allAsList(futures).get();
executorService.shutdown();
Assert.assertTrue(String.format("Index too small %d, expected %d across %d loops", index.get(), totalIndexSize, loops), index.get()>=totalIndexSize);
for(int i = 0; i < index.get(); ++i){
Assert.assertEquals(i, concurrentIndexible.get(i).intValue());
}
concurrentIndexible.clear();
futures.clear();
}
}

View File

@ -19,17 +19,42 @@
package io.druid.segment.data;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.metamx.common.guava.Sequences;
import io.druid.data.input.MapBasedInputRow;
import io.druid.data.input.Row;
import io.druid.granularity.QueryGranularity;
import io.druid.query.Druids;
import io.druid.query.FinalizeResultsQueryRunner;
import io.druid.query.QueryConfig;
import io.druid.query.QueryRunner;
import io.druid.query.QueryRunnerFactory;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.Result;
import io.druid.query.TestQueryRunners;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.timeseries.TimeseriesQuery;
import io.druid.query.timeseries.TimeseriesQueryEngine;
import io.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import io.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import io.druid.query.timeseries.TimeseriesResultValue;
import io.druid.segment.IncrementalIndexSegment;
import io.druid.segment.Segment;
import io.druid.segment.incremental.IncrementalIndex;
import io.druid.segment.incremental.IndexSizeExceededException;
import io.druid.segment.incremental.OffheapIncrementalIndex;
import io.druid.segment.incremental.OnheapIncrementalIndex;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
@ -39,12 +64,18 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
*/
@ -53,7 +84,7 @@ public class IncrementalIndexTest
{
interface IndexCreator
{
public IncrementalIndex createIndex();
public IncrementalIndex createIndex(AggregatorFactory[] aggregatorFactories);
}
private final IndexCreator indexCreator;
@ -74,9 +105,9 @@ public class IncrementalIndexTest
new IndexCreator()
{
@Override
public IncrementalIndex createIndex()
public IncrementalIndex createIndex(AggregatorFactory[] factories)
{
return IncrementalIndexTest.createIndex(true);
return IncrementalIndexTest.createIndex(true, factories);
}
}
},
@ -84,9 +115,9 @@ public class IncrementalIndexTest
new IndexCreator()
{
@Override
public IncrementalIndex createIndex()
public IncrementalIndex createIndex(AggregatorFactory[] factories)
{
return IncrementalIndexTest.createIndex(false);
return IncrementalIndexTest.createIndex(false, factories);
}
}
}
@ -95,20 +126,23 @@ public class IncrementalIndexTest
);
}
public static IncrementalIndex createIndex(boolean offheap)
public static IncrementalIndex createIndex(boolean offheap, AggregatorFactory[] aggregatorFactories)
{
if (null == aggregatorFactories) {
aggregatorFactories = defaultAggregatorFactories;
}
if (offheap) {
return new OffheapIncrementalIndex(
0L,
QueryGranularity.NONE,
new AggregatorFactory[]{new CountAggregatorFactory("count")},
aggregatorFactories,
TestQueryRunners.pool,
true,
100 * 1024 * 1024
);
} else {
return new OnheapIncrementalIndex(
0L, QueryGranularity.NONE, new AggregatorFactory[]{new CountAggregatorFactory("count")}, 1000
0L, QueryGranularity.NONE, aggregatorFactories, 1000000
);
}
}
@ -144,11 +178,29 @@ public class IncrementalIndexTest
return new MapBasedInputRow(timestamp, dimensionList, builder.build());
}
private static MapBasedInputRow getLongRow(long timestamp, int rowID, int dimensionCount)
{
List<String> dimensionList = new ArrayList<String>(dimensionCount);
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
for (int i = 0; i < dimensionCount; i++) {
String dimName = String.format("Dim_%d", i);
dimensionList.add(dimName);
builder.put(dimName, (Long) 1l);
}
return new MapBasedInputRow(timestamp, dimensionList, builder.build());
}
private static final AggregatorFactory[] defaultAggregatorFactories = new AggregatorFactory[]{
new CountAggregatorFactory(
"count"
)
};
@Test
public void testCaseSensitivity() throws Exception
{
long timestamp = System.currentTimeMillis();
IncrementalIndex index = indexCreator.createIndex();
IncrementalIndex index = indexCreator.createIndex(defaultAggregatorFactories);
populateIndex(timestamp, index);
Assert.assertEquals(Arrays.asList("dim1", "dim2"), index.getDimensions());
Assert.assertEquals(2, index.size());
@ -165,10 +217,183 @@ public class IncrementalIndexTest
Assert.assertEquals(Arrays.asList("4"), row.getDimension("dim2"));
}
@Test(timeout = 60000)
public void testConcurrentAddRead() throws InterruptedException, ExecutionException
{
final int dimensionCount = 5;
final ArrayList<AggregatorFactory> ingestAggregatorFactories = new ArrayList<>(dimensionCount + 1);
ingestAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < dimensionCount; ++i) {
ingestAggregatorFactories.add(
new LongSumAggregatorFactory(
String.format("sumResult%s", i),
String.format("Dim_%s", i)
)
);
ingestAggregatorFactories.add(
new DoubleSumAggregatorFactory(
String.format("doubleSumResult%s", i),
String.format("Dim_%s", i)
)
);
}
final ArrayList<AggregatorFactory> queryAggregatorFactories = new ArrayList<>(dimensionCount + 1);
queryAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < dimensionCount; ++i) {
queryAggregatorFactories.add(
new LongSumAggregatorFactory(
String.format("sumResult%s", i),
String.format("sumResult%s", i)
)
);
queryAggregatorFactories.add(
new DoubleSumAggregatorFactory(
String.format("doubleSumResult%s", i),
String.format("doubleSumResult%s", i)
)
);
}
final IncrementalIndex index = indexCreator.createIndex(ingestAggregatorFactories.toArray(new AggregatorFactory[dimensionCount]));
final int taskCount = 30;
final int concurrentThreads = 3;
final int elementsPerThread = 100;
final ListeningExecutorService indexExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("index-executor-%d")
.setPriority(Thread.MIN_PRIORITY)
.build()
)
);
final ListeningExecutorService queryExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("query-executor-%d")
.build()
)
);
final long timestamp = System.currentTimeMillis();
final Interval queryInterval = new Interval("1900-01-01T00:00:00Z/2900-01-01T00:00:00Z");
final List<ListenableFuture<?>> indexFutures = new LinkedList<>();
final List<ListenableFuture<?>> queryFutures = new LinkedList<>();
final Segment incrementalIndexSegment = new IncrementalIndexSegment(index, null);
final QueryRunnerFactory factory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(new QueryConfig()),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
);
final AtomicInteger currentlyRunning = new AtomicInteger(0);
final AtomicBoolean concurrentlyRan = new AtomicBoolean(false);
final AtomicBoolean someoneRan = new AtomicBoolean(false);
for (int j = 0; j < taskCount; j++) {
indexFutures.add(
indexExecutor.submit(
new Runnable()
{
@Override
public void run()
{
currentlyRunning.incrementAndGet();
try {
for (int i = 0; i < elementsPerThread; i++) {
index.add(getLongRow(timestamp + i, i, dimensionCount));
}
}
catch (IndexSizeExceededException e) {
throw Throwables.propagate(e);
}
currentlyRunning.decrementAndGet();
someoneRan.set(true);
}
}
)
);
queryFutures.add(
queryExecutor.submit(
new Runnable()
{
@Override
public void run()
{
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(QueryGranularity.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
Map<String, Object> context = new HashMap<String, Object>();
for (Result<TimeseriesResultValue> result :
Sequences.toList(
runner.run(query, context),
new LinkedList<Result<TimeseriesResultValue>>()
)
) {
if (someoneRan.get()) {
Assert.assertTrue(result.getValue().getDoubleMetric("doubleSumResult0") > 0);
}
}
if (currentlyRunning.get() > 0) {
concurrentlyRan.set(true);
}
}
}
)
);
}
List<ListenableFuture<?>> allFutures = new ArrayList<>(queryFutures.size() + indexFutures.size());
allFutures.addAll(queryFutures);
allFutures.addAll(indexFutures);
Futures.allAsList(allFutures).get();
Assert.assertTrue("Did not hit concurrency, please try again", concurrentlyRan.get());
queryExecutor.shutdown();
indexExecutor.shutdown();
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(QueryGranularity.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
Map<String, Object> context = new HashMap<String, Object>();
List<Result<TimeseriesResultValue>> results = Sequences.toList(
runner.run(query, context),
new LinkedList<Result<TimeseriesResultValue>>()
);
for (Result<TimeseriesResultValue> result : results) {
Assert.assertEquals(elementsPerThread, result.getValue().getLongMetric("rows").intValue());
for (int i = 0; i < dimensionCount; ++i) {
Assert.assertEquals(
String.format("Failed long sum on dimension %d", i),
elementsPerThread * taskCount,
result.getValue().getLongMetric(String.format("sumResult%s", i)).intValue()
);
Assert.assertEquals(
String.format("Failed double sum on dimension %d", i),
elementsPerThread * taskCount,
result.getValue().getDoubleMetric(String.format("doubleSumResult%s", i)).intValue()
);
}
}
}
@Test
public void testConcurrentAdd() throws Exception
{
final IncrementalIndex index = indexCreator.createIndex();
final IncrementalIndex index = indexCreator.createIndex(defaultAggregatorFactories);
final int threadCount = 10;
final int elementsPerThread = 200;
final int dimensionCount = 5;

View File

@ -0,0 +1,392 @@
/*
* Druid - a distributed column store.
* Copyright (C) 2012, 2013, 2014 Metamarkets Group Inc.
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package io.druid.segment.incremental;
import com.carrotsearch.junitbenchmarks.AbstractBenchmark;
import com.carrotsearch.junitbenchmarks.BenchmarkOptions;
import com.carrotsearch.junitbenchmarks.Clock;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.metamx.common.guava.Sequences;
import io.druid.data.input.InputRow;
import io.druid.data.input.MapBasedInputRow;
import io.druid.granularity.QueryGranularity;
import io.druid.query.Druids;
import io.druid.query.FinalizeResultsQueryRunner;
import io.druid.query.QueryConfig;
import io.druid.query.QueryRunner;
import io.druid.query.QueryRunnerFactory;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.Result;
import io.druid.query.aggregation.Aggregator;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.DoubleSumAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.timeseries.TimeseriesQuery;
import io.druid.query.timeseries.TimeseriesQueryEngine;
import io.druid.query.timeseries.TimeseriesQueryQueryToolChest;
import io.druid.query.timeseries.TimeseriesQueryRunnerFactory;
import io.druid.query.timeseries.TimeseriesResultValue;
import io.druid.segment.IncrementalIndexSegment;
import io.druid.segment.Segment;
import org.joda.time.Interval;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Extending AbstractBenchmark means only runs if explicitly called
*/
@RunWith(Parameterized.class)
public class OnheapIncrementalIndexBenchmark extends AbstractBenchmark
{
private static AggregatorFactory[] factories;
static final int dimensionCount = 5;
static {
final ArrayList<AggregatorFactory> ingestAggregatorFactories = new ArrayList<>(dimensionCount + 1);
ingestAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < dimensionCount; ++i) {
ingestAggregatorFactories.add(
new LongSumAggregatorFactory(
String.format("sumResult%s", i),
String.format("Dim_%s", i)
)
);
ingestAggregatorFactories.add(
new DoubleSumAggregatorFactory(
String.format("doubleSumResult%s", i),
String.format("Dim_%s", i)
)
);
}
factories = ingestAggregatorFactories.toArray(new AggregatorFactory[0]);
}
private static final class MapIncrementalIndex extends OnheapIncrementalIndex
{
private final AtomicInteger indexIncrement = new AtomicInteger(0);
ConcurrentHashMap<Integer, Aggregator[]> indexedMap = new ConcurrentHashMap<Integer, Aggregator[]>();
public MapIncrementalIndex(
long minTimestamp,
QueryGranularity gran,
AggregatorFactory[] metrics,
int maxRowCount
)
{
super(minTimestamp, gran, metrics, maxRowCount);
}
@Override
protected Aggregator[] concurrentGet(int offset)
{
// All get operations should be fine
return indexedMap.get(offset);
}
@Override
protected void concurrentSet(int offset, Aggregator[] value)
{
indexedMap.put(offset, value);
}
@Override
protected Integer addToFacts(
AggregatorFactory[] metrics,
boolean deserializeComplexMetrics,
InputRow row,
AtomicInteger numEntries,
TimeAndDims key,
ThreadLocal<InputRow> in
) throws IndexSizeExceededException
{
final Integer priorIdex = getFacts().get(key);
Aggregator[] aggs;
if (null != priorIdex) {
aggs = indexedMap.get(priorIdex);
} else {
aggs = new Aggregator[metrics.length];
for (int i = 0; i < metrics.length; i++) {
final AggregatorFactory agg = metrics[i];
aggs[i] = agg.factorize(
makeColumnSelectorFactory(agg, in, deserializeComplexMetrics)
);
}
Integer rowIndex;
do {
rowIndex = indexIncrement.incrementAndGet();
} while (null != indexedMap.putIfAbsent(rowIndex, aggs));
// Last ditch sanity checks
if (numEntries.get() >= maxRowCount && !getFacts().containsKey(key)) {
throw new IndexSizeExceededException("Maximum number of rows reached");
}
final Integer prev = getFacts().putIfAbsent(key, rowIndex);
if (null == prev) {
numEntries.incrementAndGet();
} else {
// We lost a race
aggs = indexedMap.get(prev);
// Free up the misfire
indexedMap.remove(rowIndex);
// This is expected to occur ~80% of the time in the worst scenarios
}
}
in.set(row);
for (Aggregator agg : aggs) {
synchronized (agg) {
agg.aggregate();
}
}
in.set(null);
return numEntries.get();
}
}
@Parameterized.Parameters
public static Collection<Object[]> getParameters()
{
return ImmutableList.<Object[]>of(
new Object[]{OnheapIncrementalIndex.class},
new Object[]{MapIncrementalIndex.class}
);
}
private final Class<? extends OnheapIncrementalIndex> incrementalIndex;
public OnheapIncrementalIndexBenchmark(Class<? extends OnheapIncrementalIndex> incrementalIndex)
{
this.incrementalIndex = incrementalIndex;
}
private static MapBasedInputRow getLongRow(long timestamp, int rowID, int dimensionCount)
{
List<String> dimensionList = new ArrayList<String>(dimensionCount);
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
for (int i = 0; i < dimensionCount; i++) {
String dimName = String.format("Dim_%d", i);
dimensionList.add(dimName);
builder.put(dimName, new Integer(rowID).longValue());
}
return new MapBasedInputRow(timestamp, dimensionList, builder.build());
}
@Test
@BenchmarkOptions(callgc = true, clock = Clock.REAL_TIME, warmupRounds = 10, benchmarkRounds = 20)
public void testConcurrentAddRead()
throws InterruptedException, ExecutionException, NoSuchMethodException, IllegalAccessException,
InvocationTargetException, InstantiationException
{
final int taskCount = 30;
final int concurrentThreads = 3;
final int elementsPerThread = 1 << 15;
final OnheapIncrementalIndex incrementalIndex = this.incrementalIndex.getConstructor(
Long.TYPE,
QueryGranularity.class,
AggregatorFactory[].class,
Integer.TYPE
).newInstance(0, QueryGranularity.NONE, factories, elementsPerThread * taskCount);
final ArrayList<AggregatorFactory> queryAggregatorFactories = new ArrayList<>(dimensionCount + 1);
queryAggregatorFactories.add(new CountAggregatorFactory("rows"));
for (int i = 0; i < dimensionCount; ++i) {
queryAggregatorFactories.add(
new LongSumAggregatorFactory(
String.format("sumResult%s", i),
String.format("sumResult%s", i)
)
);
queryAggregatorFactories.add(
new DoubleSumAggregatorFactory(
String.format("doubleSumResult%s", i),
String.format("doubleSumResult%s", i)
)
);
}
final ListeningExecutorService indexExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("index-executor-%d")
.setPriority(Thread.MIN_PRIORITY)
.build()
)
);
final ListeningExecutorService queryExecutor = MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
concurrentThreads,
new ThreadFactoryBuilder()
.setDaemon(false)
.setNameFormat("query-executor-%d")
.build()
)
);
final long timestamp = System.currentTimeMillis();
final Interval queryInterval = new Interval("1900-01-01T00:00:00Z/2900-01-01T00:00:00Z");
final List<ListenableFuture<?>> indexFutures = new LinkedList<>();
final List<ListenableFuture<?>> queryFutures = new LinkedList<>();
final Segment incrementalIndexSegment = new IncrementalIndexSegment(incrementalIndex, null);
final QueryRunnerFactory factory = new TimeseriesQueryRunnerFactory(
new TimeseriesQueryQueryToolChest(new QueryConfig()),
new TimeseriesQueryEngine(),
QueryRunnerTestHelper.NOOP_QUERYWATCHER
);
final AtomicInteger currentlyRunning = new AtomicInteger(0);
final AtomicBoolean concurrentlyRan = new AtomicBoolean(false);
final AtomicBoolean someoneRan = new AtomicBoolean(false);
for (int j = 0; j < taskCount; j++) {
indexFutures.add(
indexExecutor.submit(
new Runnable()
{
@Override
public void run()
{
currentlyRunning.incrementAndGet();
try {
for (int i = 0; i < elementsPerThread; i++) {
incrementalIndex.add(getLongRow(timestamp + i, 1, dimensionCount));
}
}
catch (IndexSizeExceededException e) {
throw Throwables.propagate(e);
}
currentlyRunning.decrementAndGet();
someoneRan.set(true);
}
}
)
);
queryFutures.add(
queryExecutor.submit(
new Runnable()
{
@Override
public void run()
{
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(QueryGranularity.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
Map<String, Object> context = new HashMap<String, Object>();
for (Result<TimeseriesResultValue> result :
Sequences.toList(
runner.run(query, context),
new LinkedList<Result<TimeseriesResultValue>>()
)
) {
if (someoneRan.get()) {
Assert.assertTrue(result.getValue().getDoubleMetric("doubleSumResult0") > 0);
}
}
if (currentlyRunning.get() > 0) {
concurrentlyRan.set(true);
}
}
}
)
);
}
List<ListenableFuture<?>> allFutures = new ArrayList<>(queryFutures.size() + indexFutures.size());
allFutures.addAll(queryFutures);
allFutures.addAll(indexFutures);
Futures.allAsList(allFutures).get();
//Assert.assertTrue("Did not hit concurrency, please try again", concurrentlyRan.get());
queryExecutor.shutdown();
indexExecutor.shutdown();
QueryRunner<Result<TimeseriesResultValue>> runner = new FinalizeResultsQueryRunner<Result<TimeseriesResultValue>>(
factory.createRunner(incrementalIndexSegment),
factory.getToolchest()
);
TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource("xxx")
.granularity(QueryGranularity.ALL)
.intervals(ImmutableList.of(queryInterval))
.aggregators(queryAggregatorFactories)
.build();
Map<String, Object> context = new HashMap<String, Object>();
List<Result<TimeseriesResultValue>> results = Sequences.toList(
runner.run(query, context),
new LinkedList<Result<TimeseriesResultValue>>()
);
final int expectedVal = elementsPerThread * taskCount;
for (Result<TimeseriesResultValue> result : results) {
Assert.assertEquals(elementsPerThread, result.getValue().getLongMetric("rows").intValue());
for (int i = 0; i < dimensionCount; ++i) {
Assert.assertEquals(
String.format("Failed long sum on dimension %d", i),
expectedVal,
result.getValue().getLongMetric(String.format("sumResult%s", i)).intValue()
);
Assert.assertEquals(
String.format("Failed double sum on dimension %d", i),
expectedVal,
result.getValue().getDoubleMetric(String.format("doubleSumResult%s", i)).intValue()
);
}
}
}
}