fix broken CombiningSequence.toYielder behavior

This commit is contained in:
Xavier Léauté 2014-05-13 16:40:54 -07:00
parent c57a18d6b6
commit 62c55eaf26
2 changed files with 176 additions and 27 deletions

View File

@ -73,38 +73,57 @@ public class CombiningSequence<T> implements Sequence<T>
final CombiningYieldingAccumulator<OutType, T> combiningAccumulator = new CombiningYieldingAccumulator<OutType, T>(
ordering, mergeFn, accumulator
);
combiningAccumulator.setRetVal(initValue);
Yielder<T> baseYielder = baseSequence.toYielder(null, combiningAccumulator);
if (baseYielder.isDone()) {
return Yielders.done(initValue, baseYielder);
}
return makeYielder(baseYielder, combiningAccumulator);
return makeYielder(baseYielder, combiningAccumulator, false);
}
public <OutType, T> Yielder<OutType> makeYielder(
final Yielder<T> yielder, final CombiningYieldingAccumulator<OutType, T> combiningAccumulator
Yielder<T> yielder,
final CombiningYieldingAccumulator<OutType, T> combiningAccumulator,
boolean finalValue
)
{
final Yielder<T> finalYielder;
final OutType retVal;
final boolean finalFinalValue;
if(!yielder.isDone()) {
retVal = combiningAccumulator.getRetVal();
finalYielder = yielder.next(yielder.get());
finalFinalValue = false;
} else {
if(!finalValue && combiningAccumulator.accumulatedSomething()) {
combiningAccumulator.accumulateLastValue();
retVal = combiningAccumulator.getRetVal();
finalFinalValue = true;
if(!combiningAccumulator.yielded()) {
return Yielders.done(null, yielder);
} else {
finalYielder = Yielders.done(null, yielder);
}
}
else {
return Yielders.done(null, yielder);
}
}
return new Yielder<OutType>()
{
@Override
public OutType get()
{
return combiningAccumulator.getRetVal();
return retVal;
}
@Override
public Yielder<OutType> next(OutType outType)
public Yielder<OutType> next(OutType initValue)
{
T nextIn = yielder.get();
combiningAccumulator.setRetVal(outType);
final Yielder<T> baseYielder = yielder.next(nextIn);
if (baseYielder.isDone()) {
final OutType outValue = combiningAccumulator.getAccumulator().accumulate(outType, baseYielder.get());
return Yielders.done(outValue, baseYielder);
}
return makeYielder(baseYielder, combiningAccumulator);
return makeYielder(finalYielder, combiningAccumulator, finalFinalValue);
}
@Override
@ -116,7 +135,7 @@ public class CombiningSequence<T> implements Sequence<T>
@Override
public void close() throws IOException
{
yielder.close();
finalYielder.close();
}
};
}
@ -128,6 +147,8 @@ public class CombiningSequence<T> implements Sequence<T>
private final YieldingAccumulator<OutType, T> accumulator;
private volatile OutType retVal;
private volatile T lastMergedVal;
private volatile boolean accumulatedSomething = false;
public CombiningYieldingAccumulator(
Ordering<T> ordering,
@ -173,17 +194,34 @@ public class CombiningSequence<T> implements Sequence<T>
@Override
public T accumulate(T prevValue, T t)
{
if (!accumulatedSomething) {
accumulatedSomething = true;
}
if (prevValue == null) {
return mergeFn.apply(t, null);
lastMergedVal = mergeFn.apply(t, null);
return lastMergedVal;
}
if (ordering.compare(prevValue, t) == 0) {
return mergeFn.apply(prevValue, t);
lastMergedVal = mergeFn.apply(prevValue, t);
return lastMergedVal;
}
lastMergedVal = t;
retVal = accumulator.accumulate(retVal, prevValue);
return t;
}
public void accumulateLastValue()
{
retVal = accumulator.accumulate(retVal, lastMergedVal);
}
public boolean accumulatedSomething()
{
return accumulatedSomething;
}
}
private class CombiningAccumulator<OutType> implements Accumulator<T, T>

View File

@ -19,6 +19,8 @@
package io.druid.common.guava;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.metamx.common.Pair;
@ -29,16 +31,34 @@ import com.metamx.common.guava.YieldingAccumulator;
import com.metamx.common.guava.nary.BinaryFn;
import junit.framework.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
/**
*/
@RunWith(Parameterized.class)
public class CombiningSequenceTest
{
@Parameterized.Parameters
public static Collection<Object[]> valuesToTry()
{
return Arrays.asList(new Object[][] {
{1}, {2}, {3}, {4}, {5}, {1000}
});
}
private final int yieldEvery;
public CombiningSequenceTest(int yieldEvery)
{
this.yieldEvery = yieldEvery;
}
@Test
public void testMerge() throws IOException
{
@ -65,6 +85,75 @@ public class CombiningSequenceTest
testCombining(pairs, expected);
}
@Test
public void testNoMergeOne() throws IOException
{
List<Pair<Integer, Integer>> pairs = Arrays.asList(
Pair.of(0, 1)
);
List<Pair<Integer, Integer>> expected = Arrays.asList(
Pair.of(0, 1)
);
testCombining(pairs, expected);
}
@Test
public void testMergeMany() throws IOException
{
List<Pair<Integer, Integer>> pairs = Arrays.asList(
Pair.of(0, 6),
Pair.of(1, 1),
Pair.of(2, 1),
Pair.of(5, 11),
Pair.of(6, 1),
Pair.of(5, 1)
);
List<Pair<Integer, Integer>> expected = Arrays.asList(
Pair.of(0, 6),
Pair.of(1, 1),
Pair.of(2, 1),
Pair.of(5, 11),
Pair.of(6, 1),
Pair.of(5, 1)
);
testCombining(pairs, expected);
}
@Test
public void testNoMergeTwo() throws IOException
{
List<Pair<Integer, Integer>> pairs = Arrays.asList(
Pair.of(0, 1),
Pair.of(1, 1)
);
List<Pair<Integer, Integer>> expected = Arrays.asList(
Pair.of(0, 1),
Pair.of(1, 1)
);
testCombining(pairs, expected);
}
@Test
public void testMergeTwo() throws IOException
{
List<Pair<Integer, Integer>> pairs = Arrays.asList(
Pair.of(0, 1),
Pair.of(0, 1)
);
List<Pair<Integer, Integer>> expected = Arrays.asList(
Pair.of(0, 2)
);
testCombining(pairs, expected);
}
@Test
public void testMergeSomeThingsMergedAtEnd() throws IOException
{
@ -136,28 +225,50 @@ public class CombiningSequenceTest
null,
new YieldingAccumulator<Pair<Integer, Integer>, Pair<Integer, Integer>>()
{
int count = 0;
@Override
public Pair<Integer, Integer> accumulate(
Pair<Integer, Integer> lhs, Pair<Integer, Integer> rhs
)
{
yield();
count++;
if(count % yieldEvery == 0) yield();
return rhs;
}
}
);
Iterator<Pair<Integer, Integer>> expectedVals = expected.iterator();
Iterator<Pair<Integer, Integer>> expectedVals = Iterators.filter(
expected.iterator(),
new Predicate<Pair<Integer, Integer>>()
{
int count = 0;
@Override
public boolean apply(
@Nullable Pair<Integer, Integer> input
)
{
count++;
if (count % yieldEvery == 0) {
return true;
}
return false;
}
}
);
if (expectedVals.hasNext()) {
while (!yielder.isDone()) {
final Pair<Integer, Integer> nextVal = expectedVals.next();
Assert.assertEquals(nextVal, yielder.get());
yielder = yielder.next(null);
final Pair<Integer, Integer> expectedVal = expectedVals.next();
final Pair<Integer, Integer> actual = yielder.get();
Assert.assertEquals(expectedVal, actual);
yielder = yielder.next(actual);
}
Assert.assertEquals(expectedVals.next(), yielder.get());
}
Assert.assertTrue(yielder.isDone());
Assert.assertFalse(expectedVals.hasNext());
yielder.close();
}
}