diff --git a/processing/src/main/java/io/druid/query/aggregation/Aggregator.java b/processing/src/main/java/io/druid/query/aggregation/Aggregator.java
index 95495ddea53..5a412fa1f55 100644
--- a/processing/src/main/java/io/druid/query/aggregation/Aggregator.java
+++ b/processing/src/main/java/io/druid/query/aggregation/Aggregator.java
@@ -19,6 +19,8 @@
 
 package io.druid.query.aggregation;
 
+import java.io.Closeable;
+
 /**
  * An Aggregator is an object that can aggregate metrics. Its aggregation-related methods (namely, aggregate() and get())
  * do not take any arguments as the assumption is that the Aggregator was given something in its constructor that
@@ -32,7 +34,8 @@ package io.druid.query.aggregation;
  *
  * This interface is old and going away. It is being replaced by BufferAggregator
  */
-public interface Aggregator {
+public interface Aggregator extends Closeable
+{
   void aggregate();
   void reset();
   Object get();
diff --git a/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java b/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java
index d5be53ae9a6..7d9514d85f0 100644
--- a/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java
+++ b/processing/src/main/java/io/druid/segment/incremental/OnheapIncrementalIndex.java
@@ -20,7 +20,9 @@
 package io.druid.segment.incremental;
 
 import com.google.common.base.Supplier;
+import com.google.common.base.Throwables;
 import com.google.common.collect.Maps;
+import com.google.common.io.Closer;
 import io.druid.data.input.InputRow;
 import io.druid.data.input.impl.DimensionsSpec;
 import io.druid.granularity.QueryGranularity;
@@ -37,6 +39,7 @@ import io.druid.segment.ObjectColumnSelector;
 import io.druid.segment.column.ColumnCapabilities;
 
 import javax.annotation.Nullable;
+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -257,6 +260,23 @@ public class OnheapIncrementalIndex extends IncrementalIndex
     rowContainer.set(null);
   }
 
+  private void closeAggregators()
+  {
+    Closer closer = Closer.create();
+    for (Aggregator[] aggs : aggregators.values()) {
+      for (Aggregator agg : aggs) {
+        closer.register(agg);
+      }
+    }
+
+    try {
+      closer.close();
+    }
+    catch (IOException e) {
+      Throwables.propagate(e);
+    }
+  }
+
   protected Aggregator[] concurrentGet(int offset)
   {
     // All get operations should be fine
@@ -327,6 +347,7 @@
   public void close()
   {
     super.close();
+    closeAggregators();
     aggregators.clear();
     facts.clear();
     if (selectors != null) {
diff --git a/processing/src/test/java/io/druid/segment/incremental/OnheapIncrementalIndexTest.java b/processing/src/test/java/io/druid/segment/incremental/OnheapIncrementalIndexTest.java
index d4a864c91f1..318ee388364 100644
--- a/processing/src/test/java/io/druid/segment/incremental/OnheapIncrementalIndexTest.java
+++ b/processing/src/test/java/io/druid/segment/incremental/OnheapIncrementalIndexTest.java
@@ -23,8 +23,11 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import io.druid.data.input.MapBasedInputRow;
 import io.druid.granularity.QueryGranularities;
+import io.druid.query.aggregation.Aggregator;
 import io.druid.query.aggregation.AggregatorFactory;
+import io.druid.query.aggregation.LongMaxAggregator;
 import io.druid.query.aggregation.LongMaxAggregatorFactory;
 
+import org.easymock.EasyMock;
 import org.junit.Assert;
 import org.junit.Test;
@@ -97,4 +100,34 @@ public class OnheapIncrementalIndexTest
 
     Assert.assertEquals(0, checkFailedCount.get());
   }
+
+  @Test
+  public void testOnHeapIncrementalIndexClose() throws Exception
+  {
+    // Prepare the mocks & set close() call count expectation to 1
+    Aggregator mockedAggregator = EasyMock.createMock(LongMaxAggregator.class);
+    mockedAggregator.close();
+    EasyMock.expectLastCall().times(1);
+
+    final OnheapIncrementalIndex index = new OnheapIncrementalIndex(
+        0,
+        QueryGranularities.MINUTE,
+        new AggregatorFactory[]{new LongMaxAggregatorFactory("max", "max")},
+        MAX_ROWS
+    );
+
+    index.add(new MapBasedInputRow(
+        0,
+        Lists.newArrayList("billy"),
+        ImmutableMap.of("billy", 1, "max", 1)
+    ));
+
+    // override the aggregators with the mocks
+    index.concurrentGet(0)[0] = mockedAggregator;
+
+    // close the indexer and validate the expectations
+    EasyMock.replay(mockedAggregator);
+    index.close();
+    EasyMock.verify(mockedAggregator);
+  }
 }
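
For context, the sketch below is not part of the patch. It is a minimal, hypothetical illustration of why making Aggregator extend Closeable matters and of the Closer-based teardown used in closeAggregators(). The ToyMaxAggregator class and closeAll() helper are invented names; only the Guava Closer and Throwables calls mirror the patch.

```java
import com.google.common.base.Throwables;
import com.google.common.io.Closer;

import java.io.Closeable;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

public class CloseableAggregatorSketch
{
  // Hypothetical aggregator-like class: stands in for an Aggregator that holds
  // resources (off-heap memory, handles to external sketches, etc.) which must
  // be released once the owning index is closed.
  static class ToyMaxAggregator implements Closeable
  {
    private long max = Long.MIN_VALUE;
    private boolean closed = false;

    void aggregate(long value)
    {
      if (closed) {
        throw new IllegalStateException("aggregate() called after close()");
      }
      max = Math.max(max, value);
    }

    long get()
    {
      return max;
    }

    @Override
    public void close()
    {
      // A real implementation would free native memory, drop sketch handles, etc.
      closed = true;
    }
  }

  // Mirrors the closeAggregators() teardown in the patch: register every
  // aggregator with one Guava Closer so that all of them are closed even if
  // an individual close() throws; Closer rethrows the first failure afterwards.
  static void closeAll(List<? extends Closeable> aggregators)
  {
    Closer closer = Closer.create();
    for (Closeable aggregator : aggregators) {
      closer.register(aggregator);
    }
    try {
      closer.close();
    }
    catch (IOException e) {
      throw Throwables.propagate(e);
    }
  }

  public static void main(String[] args)
  {
    ToyMaxAggregator agg = new ToyMaxAggregator();
    agg.aggregate(3);
    agg.aggregate(7);
    System.out.println("max = " + agg.get());   // prints 7
    closeAll(Arrays.asList(agg));
  }
}
```

Registering every aggregator with a single Closer before closing means that a failure in one close() does not prevent the remaining aggregators from being closed, which is why the patch prefers it over a plain loop of close() calls.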