diff --git a/processing/src/main/java/io/druid/segment/incremental/IncrementalIndex.java b/processing/src/main/java/io/druid/segment/incremental/IncrementalIndex.java index 9866618df7e..057a4b95ff2 100644 --- a/processing/src/main/java/io/druid/segment/incremental/IncrementalIndex.java +++ b/processing/src/main/java/io/druid/segment/incremental/IncrementalIndex.java @@ -152,14 +152,14 @@ public class IncrementalIndex implements Iterable } final List rowDimensions = row.getDimensions(); - String[][] dims = new String[dimensionOrder.size()][]; + String[][] dims; List overflow = null; - for (String dimension : rowDimensions) { - dimension = dimension.toLowerCase(); - List dimensionValues = row.getDimension(dimension); - - synchronized (dimensionOrder) { + synchronized (dimensionOrder) { + dims = new String[dimensionOrder.size()][]; + for (String dimension : rowDimensions) { + dimension = dimension.toLowerCase(); + List dimensionValues = row.getDimension(dimension); Integer index = dimensionOrder.get(dimension); if (index == null) { dimensionOrder.put(dimension, dimensionOrder.size()); @@ -175,6 +175,7 @@ public class IncrementalIndex implements Iterable } } + if (overflow != null) { // Merge overflow and non-overflow String[][] newDims = new String[dims.length + overflow.size()][]; @@ -287,8 +288,9 @@ public class IncrementalIndex implements Iterable Aggregator[] prev = facts.putIfAbsent(key, aggs); if (prev != null) { aggs = prev; + } else { + numEntries.incrementAndGet(); } - numEntries.incrementAndGet(); } synchronized (this) { diff --git a/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java b/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java index 5a7b448a4a8..35fb2b81c0e 100644 --- a/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java +++ b/processing/src/test/java/io/druid/segment/data/IncrementalIndexTest.java @@ -24,18 +24,59 @@ import io.druid.data.input.MapBasedInputRow; import io.druid.data.input.Row; import io.druid.granularity.QueryGranularity; import io.druid.query.aggregation.AggregatorFactory; +import io.druid.query.aggregation.CountAggregatorFactory; import io.druid.segment.incremental.IncrementalIndex; import junit.framework.Assert; import org.junit.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; /** */ public class IncrementalIndexTest { + public static IncrementalIndex createCaseInsensitiveIndex(long timestamp) + { + IncrementalIndex index = new IncrementalIndex(0L, QueryGranularity.NONE, new AggregatorFactory[]{}); + + index.add( + new MapBasedInputRow( + timestamp, + Arrays.asList("Dim1", "DiM2"), + ImmutableMap.of("dim1", "1", "dim2", "2", "DIM1", "3", "dIM2", "4") + ) + ); + + index.add( + new MapBasedInputRow( + timestamp, + Arrays.asList("diM1", "dIM2"), + ImmutableMap.of("Dim1", "1", "DiM2", "2", "dim1", "3", "dim2", "4") + ) + ); + return index; + } + + public static MapBasedInputRow getRow(long timestamp, int rowID, int dimensionCount) + { + List dimensionList = new ArrayList(dimensionCount); + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (int i = 0; i < dimensionCount; i++) { + String dimName = String.format("Dim_%d", i); + dimensionList.add(dimName); + builder.put(dimName, dimName + rowID); + } + return new MapBasedInputRow(timestamp, dimensionList, builder.build()); + } + @Test public void testCaseInsensitivity() throws Exception { @@ -58,25 +99,52 @@ public class IncrementalIndexTest Assert.assertEquals(Arrays.asList("4"), row.getDimension("dim2")); } - public static IncrementalIndex createCaseInsensitiveIndex(long timestamp) + @Test + public void testConcurrentAdd() throws Exception { - IncrementalIndex index = new IncrementalIndex(0L, QueryGranularity.NONE, new AggregatorFactory[]{}); - - index.add( - new MapBasedInputRow( - timestamp, - Arrays.asList("Dim1", "DiM2"), - ImmutableMap.of("dim1", "1", "dim2", "2", "DIM1", "3", "dIM2", "4") - ) + final IncrementalIndex index = new IncrementalIndex( + 0L, + QueryGranularity.NONE, + new AggregatorFactory[]{new CountAggregatorFactory("count")} ); + final int threadCount = 10; + final int elementsPerThread = 200; + final int dimensionCount = 5; + ExecutorService executor = Executors.newFixedThreadPool(threadCount); + final long timestamp = System.currentTimeMillis(); + final CountDownLatch latch = new CountDownLatch(threadCount); + for (int j = 0; j < threadCount; j++) { + executor.submit( + new Runnable() + { + @Override + public void run() + { + try { + for (int i = 0; i < elementsPerThread; i++) { + index.add(getRow(timestamp + i, i, dimensionCount)); + } + } + catch (Exception e) { + e.printStackTrace(); + } + latch.countDown(); + } + } + ); + } + Assert.assertTrue(latch.await(60, TimeUnit.SECONDS)); - index.add( - new MapBasedInputRow( - timestamp, - Arrays.asList("diM1", "dIM2"), - ImmutableMap.of("Dim1", "1", "DiM2", "2", "dim1", "3", "dim2", "4") - ) - ); - return index; + Assert.assertEquals(dimensionCount, index.getDimensions().size()); + Assert.assertEquals(elementsPerThread, index.size()); + Iterator iterator = index.iterator(); + int curr = 0; + while (iterator.hasNext()) { + Row row = iterator.next(); + Assert.assertEquals(timestamp + curr, row.getTimestampFromEpoch()); + Assert.assertEquals(Float.valueOf(threadCount), row.getFloatMetric("count")); + curr++; + } + Assert.assertEquals(elementsPerThread, curr); } }