proper query cancellation tests

This commit is contained in:
Xavier Léauté 2014-06-02 17:39:08 -07:00
parent 855c66c9ad
commit d0f9c438f8
6 changed files with 107 additions and 75 deletions

View File

@ -20,7 +20,6 @@
package io.druid.query;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.util.concurrent.ListenableFuture;
import com.metamx.common.concurrent.ExecutorServiceConfig;
@ -29,18 +28,22 @@ import com.metamx.common.guava.Sequences;
import com.metamx.common.lifecycle.Lifecycle;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import org.junit.Ignore;
import org.easymock.Capture;
import org.easymock.EasyMock;
import org.easymock.IAnswer;
import org.junit.Assert;
import org.junit.Test;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
public class ChainedExecutionQueryRunnerTest
{
@Test @Ignore
@Test
public void testQueryCancellation() throws Exception
{
ExecutorService exec = PrioritizedExecutorService.create(
@ -63,25 +66,36 @@ public class ChainedExecutionQueryRunnerTest
final CountDownLatch queriesStarted = new CountDownLatch(2);
final CountDownLatch queryIsRegistered = new CountDownLatch(1);
final Map<Query, ListenableFuture> queries = Maps.newHashMap();
QueryWatcher watcher = new QueryWatcher()
{
@Override
public void registerQuery(Query query, ListenableFuture future)
{
queries.put(query, future);
queryIsRegistered.countDown();
}
};
Capture<ListenableFuture> capturedFuture = new Capture<>();
QueryWatcher watcher = EasyMock.createStrictMock(QueryWatcher.class);
watcher.registerQuery(EasyMock.<Query>anyObject(), EasyMock.and(EasyMock.<ListenableFuture>anyObject(), EasyMock.capture(capturedFuture)));
EasyMock.expectLastCall()
.andAnswer(
new IAnswer<Void>()
{
@Override
public Void answer() throws Throwable
{
queryIsRegistered.countDown();
return null;
}
}
)
.once();
EasyMock.replay(watcher);
DyingQueryRunner runner1 = new DyingQueryRunner(queriesStarted);
DyingQueryRunner runner2 = new DyingQueryRunner(queriesStarted);
DyingQueryRunner runner3 = new DyingQueryRunner(queriesStarted);
ChainedExecutionQueryRunner chainedRunner = new ChainedExecutionQueryRunner<>(
exec,
Ordering.<Integer>natural(),
watcher,
Lists.<QueryRunner<Integer>>newArrayList(
new DyingQueryRunner(1, queriesStarted),
new DyingQueryRunner(2, queriesStarted),
new DyingQueryRunner(3, queriesStarted)
runner1,
runner2,
runner3
)
);
@ -93,7 +107,7 @@ public class ChainedExecutionQueryRunnerTest
.build()
);
Future f = Executors.newFixedThreadPool(1).submit(
Future resultFuture = Executors.newFixedThreadPool(1).submit(
new Runnable()
{
@Override
@ -104,45 +118,64 @@ public class ChainedExecutionQueryRunnerTest
}
);
// wait for query to register
queryIsRegistered.await();
queriesStarted.await();
// wait for query to register and start
Assert.assertTrue(queryIsRegistered.await(1, TimeUnit.SECONDS));
Assert.assertTrue(queriesStarted.await(1, TimeUnit.SECONDS));
// cancel the query
queries.values().iterator().next().cancel(true);
f.get();
Assert.assertTrue(capturedFuture.hasCaptured());
ListenableFuture future = capturedFuture.getValue();
future.cancel(true);
QueryInterruptedException cause = null;
try {
resultFuture.get();
} catch(ExecutionException e) {
Assert.assertTrue(e.getCause() instanceof QueryInterruptedException);
cause = (QueryInterruptedException)e.getCause();
}
Assert.assertNotNull(cause);
Assert.assertTrue(future.isCancelled());
Assert.assertTrue(runner1.hasStarted);
Assert.assertTrue(runner2.hasStarted);
Assert.assertFalse(runner3.hasStarted);
Assert.assertFalse(runner1.hasCompleted);
Assert.assertFalse(runner2.hasCompleted);
Assert.assertFalse(runner3.hasCompleted);
EasyMock.verify(watcher);
}
private static class DyingQueryRunner implements QueryRunner<Integer>
{
private final int id;
private final CountDownLatch latch;
private boolean hasStarted = false;
private boolean hasCompleted = false;
public DyingQueryRunner(int id, CountDownLatch latch) {
this.id = id;
public DyingQueryRunner(CountDownLatch latch)
{
this.latch = latch;
}
@Override
public Sequence<Integer> run(Query<Integer> query)
{
hasStarted = true;
latch.countDown();
int i = 0;
while (i >= 0) {
if(Thread.interrupted()) {
throw new QueryInterruptedException("I got killed");
}
// do a lot of work
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new QueryInterruptedException("I got killed");
}
++i;
if (Thread.interrupted()) {
throw new QueryInterruptedException("I got killed");
}
return Sequences.simple(Lists.newArrayList(i));
// do a lot of work
try {
Thread.sleep(500);
}
catch (InterruptedException e) {
throw new QueryInterruptedException("I got killed");
}
hasCompleted = true;
return Sequences.simple(Lists.newArrayList(123));
}
}
}

View File

@ -21,6 +21,7 @@ package io.druid.query;
import com.google.common.base.Function;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ListenableFuture;
import io.druid.granularity.QueryGranularity;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
@ -53,6 +54,16 @@ import java.util.List;
*/
public class QueryRunnerTestHelper
{
public static final QueryWatcher DUMMY_QUERYWATCHER = new QueryWatcher()
{
@Override
public void registerQuery(Query query, ListenableFuture future)
{
}
};
public static final String segmentId = "testSegment";
public static final String dataSource = "testing";
public static final UnionDataSource unionDataSource = new UnionDataSource(

View File

@ -41,14 +41,11 @@ public class TestQueryRunners
Segment adapter
)
{
QueryRunnerFactory factory = new TopNQueryRunnerFactory(pool, new TopNQueryQueryToolChest(topNConfig), new QueryWatcher()
{
@Override
public void registerQuery(Query query, ListenableFuture future)
{
}
});
QueryRunnerFactory factory = new TopNQueryRunnerFactory(
pool,
new TopNQueryQueryToolChest(topNConfig),
QueryRunnerTestHelper.DUMMY_QUERYWATCHER
);
return new FinalizeResultsQueryRunner<T>(
factory.createRunner(adapter),
factory.getToolchest()

View File

@ -72,14 +72,7 @@ public class TopNQueryRunnerTest
new TopNQueryRunnerFactory(
TestQueryRunners.getPool(),
new TopNQueryQueryToolChest(new TopNQueryConfig()),
new QueryWatcher()
{
@Override
public void registerQuery(Query query, ListenableFuture future)
{
}
}
QueryRunnerTestHelper.DUMMY_QUERYWATCHER
)
)
);
@ -97,14 +90,7 @@ public class TopNQueryRunnerTest
}
),
new TopNQueryQueryToolChest(new TopNQueryConfig()),
new QueryWatcher()
{
@Override
public void registerQuery(Query query, ListenableFuture future)
{
}
}
QueryRunnerTestHelper.DUMMY_QUERYWATCHER
)
)
);

View File

@ -23,9 +23,12 @@ import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ListenableFuture;
import io.druid.collections.StupidPool;
import io.druid.query.Query;
import io.druid.query.QueryRunner;
import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.QueryWatcher;
import io.druid.query.Result;
import io.druid.query.TestQueryRunners;
import io.druid.query.aggregation.AggregatorFactory;
@ -65,7 +68,8 @@ public class TopNUnionQueryTest
QueryRunnerTestHelper.makeUnionQueryRunners(
new TopNQueryRunnerFactory(
TestQueryRunners.getPool(),
new TopNQueryQueryToolChest(new TopNQueryConfig())
new TopNQueryQueryToolChest(new TopNQueryConfig()),
QueryRunnerTestHelper.DUMMY_QUERYWATCHER
)
)
);
@ -82,7 +86,8 @@ public class TopNUnionQueryTest
}
}
),
new TopNQueryQueryToolChest(new TopNQueryConfig())
new TopNQueryQueryToolChest(new TopNQueryConfig()),
QueryRunnerTestHelper.DUMMY_QUERYWATCHER
)
)
);

View File

@ -44,16 +44,14 @@ import io.druid.query.Result;
import io.druid.query.timeboundary.TimeBoundaryQuery;
import io.druid.timeline.DataSegment;
import io.druid.timeline.partition.NoneShardSpec;
import junit.framework.Assert;
import org.easymock.EasyMock;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.timeout.ReadTimeoutException;
import org.joda.time.DateTime;
import org.joda.time.Interval;
import org.junit.Rule;
import org.junit.Assert;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
@ -71,9 +69,6 @@ public class DirectDruidClientTest
}
};
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void testRun() throws Exception
{
@ -220,8 +215,13 @@ public class DirectDruidClientTest
Assert.assertEquals(0, client1.getNumOpenConnections());
thrown.expect(QueryInterruptedException.class);
Assert.assertTrue(Sequences.toList(results, Lists.newArrayList()).isEmpty());
QueryInterruptedException exception = null;
try {
Sequences.toList(results, Lists.newArrayList());
} catch(QueryInterruptedException e) {
exception = e;
}
Assert.assertNotNull(exception);
EasyMock.verify(httpClient);
}