diff --git a/common/src/main/java/io/druid/common/guava/CombiningSequence.java b/common/src/main/java/io/druid/common/guava/CombiningSequence.java index 4212e6564cd..2cc4e4b8c52 100644 --- a/common/src/main/java/io/druid/common/guava/CombiningSequence.java +++ b/common/src/main/java/io/druid/common/guava/CombiningSequence.java @@ -73,38 +73,57 @@ public class CombiningSequence implements Sequence final CombiningYieldingAccumulator combiningAccumulator = new CombiningYieldingAccumulator( ordering, mergeFn, accumulator ); + + combiningAccumulator.setRetVal(initValue); Yielder baseYielder = baseSequence.toYielder(null, combiningAccumulator); - if (baseYielder.isDone()) { - return Yielders.done(initValue, baseYielder); - } - - return makeYielder(baseYielder, combiningAccumulator); + return makeYielder(baseYielder, combiningAccumulator, false); } public Yielder makeYielder( - final Yielder yielder, final CombiningYieldingAccumulator combiningAccumulator + Yielder yielder, + final CombiningYieldingAccumulator combiningAccumulator, + boolean finalValue ) { + final Yielder finalYielder; + final OutType retVal; + final boolean finalFinalValue; + + if(!yielder.isDone()) { + retVal = combiningAccumulator.getRetVal(); + finalYielder = yielder.next(yielder.get()); + finalFinalValue = false; + } else { + if(!finalValue && combiningAccumulator.accumulatedSomething()) { + combiningAccumulator.accumulateLastValue(); + retVal = combiningAccumulator.getRetVal(); + finalFinalValue = true; + + if(!combiningAccumulator.yielded()) { + return Yielders.done(null, yielder); + } else { + finalYielder = Yielders.done(null, yielder); + } + } + else { + return Yielders.done(null, yielder); + } + } + + return new Yielder() { @Override public OutType get() { - return combiningAccumulator.getRetVal(); + return retVal; } @Override - public Yielder next(OutType outType) + public Yielder next(OutType initValue) { - T nextIn = yielder.get(); - combiningAccumulator.setRetVal(outType); - final Yielder baseYielder = yielder.next(nextIn); - if (baseYielder.isDone()) { - final OutType outValue = combiningAccumulator.getAccumulator().accumulate(outType, baseYielder.get()); - return Yielders.done(outValue, baseYielder); - } - return makeYielder(baseYielder, combiningAccumulator); + return makeYielder(finalYielder, combiningAccumulator, finalFinalValue); } @Override @@ -116,7 +135,7 @@ public class CombiningSequence implements Sequence @Override public void close() throws IOException { - yielder.close(); + finalYielder.close(); } }; } @@ -128,6 +147,8 @@ public class CombiningSequence implements Sequence private final YieldingAccumulator accumulator; private volatile OutType retVal; + private volatile T lastMergedVal; + private volatile boolean accumulatedSomething = false; public CombiningYieldingAccumulator( Ordering ordering, @@ -173,17 +194,34 @@ public class CombiningSequence implements Sequence @Override public T accumulate(T prevValue, T t) { + if (!accumulatedSomething) { + accumulatedSomething = true; + } + if (prevValue == null) { - return mergeFn.apply(t, null); + lastMergedVal = mergeFn.apply(t, null); + return lastMergedVal; } if (ordering.compare(prevValue, t) == 0) { - return mergeFn.apply(prevValue, t); + lastMergedVal = mergeFn.apply(prevValue, t); + return lastMergedVal; } + lastMergedVal = t; retVal = accumulator.accumulate(retVal, prevValue); return t; } + + public void accumulateLastValue() + { + retVal = accumulator.accumulate(retVal, lastMergedVal); + } + + public boolean accumulatedSomething() + { + return accumulatedSomething; + } } private class CombiningAccumulator implements Accumulator diff --git a/common/src/test/java/io/druid/common/guava/CombiningSequenceTest.java b/common/src/test/java/io/druid/common/guava/CombiningSequenceTest.java index 3b997f4cdcf..8e100113e5f 100644 --- a/common/src/test/java/io/druid/common/guava/CombiningSequenceTest.java +++ b/common/src/test/java/io/druid/common/guava/CombiningSequenceTest.java @@ -19,6 +19,8 @@ package io.druid.common.guava; +import com.google.common.base.Predicate; +import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import com.google.common.collect.Ordering; import com.metamx.common.Pair; @@ -29,16 +31,34 @@ import com.metamx.common.guava.YieldingAccumulator; import com.metamx.common.guava.nary.BinaryFn; import junit.framework.Assert; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import javax.annotation.Nullable; import java.io.IOException; import java.util.Arrays; +import java.util.Collection; import java.util.Iterator; import java.util.List; -/** - */ +@RunWith(Parameterized.class) public class CombiningSequenceTest { + @Parameterized.Parameters + public static Collection valuesToTry() + { + return Arrays.asList(new Object[][] { + {1}, {2}, {3}, {4}, {5}, {1000} + }); + } + + private final int yieldEvery; + + public CombiningSequenceTest(int yieldEvery) + { + this.yieldEvery = yieldEvery; + } + @Test public void testMerge() throws IOException { @@ -65,6 +85,75 @@ public class CombiningSequenceTest testCombining(pairs, expected); } + @Test + public void testNoMergeOne() throws IOException + { + List> pairs = Arrays.asList( + Pair.of(0, 1) + ); + + List> expected = Arrays.asList( + Pair.of(0, 1) + ); + + testCombining(pairs, expected); + } + + @Test + public void testMergeMany() throws IOException + { + List> pairs = Arrays.asList( + Pair.of(0, 6), + Pair.of(1, 1), + Pair.of(2, 1), + Pair.of(5, 11), + Pair.of(6, 1), + Pair.of(5, 1) + ); + + List> expected = Arrays.asList( + Pair.of(0, 6), + Pair.of(1, 1), + Pair.of(2, 1), + Pair.of(5, 11), + Pair.of(6, 1), + Pair.of(5, 1) + ); + + testCombining(pairs, expected); + } + + @Test + public void testNoMergeTwo() throws IOException + { + List> pairs = Arrays.asList( + Pair.of(0, 1), + Pair.of(1, 1) + ); + + List> expected = Arrays.asList( + Pair.of(0, 1), + Pair.of(1, 1) + ); + + testCombining(pairs, expected); + } + + @Test + public void testMergeTwo() throws IOException + { + List> pairs = Arrays.asList( + Pair.of(0, 1), + Pair.of(0, 1) + ); + + List> expected = Arrays.asList( + Pair.of(0, 2) + ); + + testCombining(pairs, expected); + } + @Test public void testMergeSomeThingsMergedAtEnd() throws IOException { @@ -136,28 +225,50 @@ public class CombiningSequenceTest null, new YieldingAccumulator, Pair>() { + int count = 0; + @Override public Pair accumulate( Pair lhs, Pair rhs ) { - yield(); + count++; + if(count % yieldEvery == 0) yield(); return rhs; } } ); - Iterator> expectedVals = expected.iterator(); + Iterator> expectedVals = Iterators.filter( + expected.iterator(), + new Predicate>() + { + int count = 0; + + @Override + public boolean apply( + @Nullable Pair input + ) + { + count++; + if (count % yieldEvery == 0) { + return true; + } + return false; + } + } + ); if (expectedVals.hasNext()) { while (!yielder.isDone()) { - final Pair nextVal = expectedVals.next(); - Assert.assertEquals(nextVal, yielder.get()); - yielder = yielder.next(null); + final Pair expectedVal = expectedVals.next(); + final Pair actual = yielder.get(); + Assert.assertEquals(expectedVal, actual); + yielder = yielder.next(actual); } - Assert.assertEquals(expectedVals.next(), yielder.get()); } Assert.assertTrue(yielder.isDone()); + Assert.assertFalse(expectedVals.hasNext()); yielder.close(); } }