Fix exception when using complex aggs with result level caching (#7614)

* Fix exception when using complex aggs with result level caching

* Add test comments

* checkstyle

* Add helper function for getting aggs from cache

* Move method to CacheStrategy

* Revert QueryToolChest changes

* Update test comments
Jonathan Wei authored 2019-05-09 13:49:11 -07:00 (committed by GitHub)
parent 2ac112151f
commit 1b577c9b1d
7 changed files with 312 additions and 22 deletions
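
In short: pullFromCache() previously ran every cached aggregator value through AggregatorFactory.deserialize(). That is correct for the segment-level cache, which stores intermediate aggregate representations, but the result-level cache stores finalized values (the output of AggregatorFactory.finalizeComputation()), and complex aggregators cannot deserialize those again. A minimal sketch of the distinction — not part of the diff, using StringLastAggregatorFactory as an arbitrary complex aggregator:

import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.SerializablePairLongString;
import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;

public class ResultLevelCacheSketch
{
  public static void main(String[] args)
  {
    final AggregatorFactory factory = new StringLastAggregatorFactory("complexMetric", "test", null);

    // Segment-level cache entries hold this intermediate (timestamp, value) pair,
    // JSON-encoded; pullFromCache rebuilds it via factory.deserialize().
    final SerializablePairLongString intermediate = new SerializablePairLongString(123L, "val1");

    // Result-level cache entries hold the finalized value instead:
    final Object finalized = factory.finalizeComputation(intermediate);
    System.out.println(finalized); // "val1", a plain String

    // Before this patch, pullFromCache(true) fed that finalized String back into
    // factory.deserialize(), which expects the intermediate pair form and fails
    // for complex aggregators. That is the exception this PR fixes.
  }
}

The fix routes all three tool chests through a shared helper, CacheStrategy.fetchAggregatorsFromCache(), which calls deserialize() only when reading from the segment-level cache.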

processing/src/main/java/org/apache/druid/query/CacheStrategy.java

@@ -22,8 +22,11 @@ package org.apache.druid.query;
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.google.common.base.Function;
 import org.apache.druid.guice.annotations.ExtensionPoint;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 
+import java.util.Iterator;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
 
 /**
  */
@@ -98,4 +101,31 @@ public interface CacheStrategy<T, CacheType, QueryType extends Query<T>>
   {
     return pullFromCache(false);
   }
+
+  /**
+   * Helper function used by TopN, GroupBy, Timeseries queries in {@link #pullFromCache(boolean)}.
+   * When using the result level cache, the agg values seen here are
+   * finalized values generated by AggregatorFactory.finalizeComputation().
+   * These finalized values are deserialized from the cache as generic Objects, which will
+   * later be reserialized and returned to the user without further modification.
+   * Because the agg values are deserialized as generic Objects, the values are subject to the same
+   * type consistency issues handled by DimensionHandlerUtils.convertObjectToType() in the pullFromCache implementations
+   * for dimension values (e.g., a Float would become a Double).
+   */
+  static void fetchAggregatorsFromCache(
+      Iterator<AggregatorFactory> aggIter,
+      Iterator<Object> resultIter,
+      boolean isResultLevelCache,
+      BiFunction<String, Object, Void> addToResultFunction
+  )
+  {
+    while (aggIter.hasNext() && resultIter.hasNext()) {
+      final AggregatorFactory factory = aggIter.next();
+      if (isResultLevelCache) {
+        addToResultFunction.apply(factory.getName(), resultIter.next());
+      } else {
+        addToResultFunction.apply(factory.getName(), factory.deserialize(resultIter.next()));
+      }
+    }
+  }
 }
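
The type-consistency caveat in the javadoc above is easy to reproduce. A small standalone sketch (not part of the diff) of what a JSON round trip does to values read back as generic Objects:

import com.fasterxml.jackson.databind.ObjectMapper;

public class CacheTypeShiftSketch
{
  public static void main(String[] args) throws Exception
  {
    ObjectMapper mapper = new ObjectMapper();

    // A Float is written as a JSON number and read back as a Double...
    Object floatBack = mapper.readValue(mapper.writeValueAsBytes(2.1f), Object.class);
    System.out.println(floatBack.getClass()); // class java.lang.Double

    // ...and a small Long comes back as an Integer.
    Object longBack = mapper.readValue(mapper.writeValueAsBytes(2L), Object.class);
    System.out.println(longBack.getClass()); // class java.lang.Integer
  }
}

This is why the new tests below compare against a type-adjusted expected result for the result-level cache: a finalized Float aggregate comes back as a Double, and a small Long comes back as an Integer.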

processing/src/main/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChest.java

@@ -555,7 +555,7 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<Row, GroupByQuery>
         DateTime timestamp = granularity.toDateTime(((Number) results.next()).longValue());
 
-        Map<String, Object> event = Maps.newLinkedHashMap();
+        final Map<String, Object> event = Maps.newLinkedHashMap();
         Iterator<DimensionSpec> dimsIter = dims.iterator();
 
         while (dimsIter.hasNext() && results.hasNext()) {
           final DimensionSpec dimensionSpec = dimsIter.next();
@@ -566,12 +566,18 @@ public class GroupByQueryQueryToolChest extends QueryToolChest<Row, GroupByQuery>
               DimensionHandlerUtils.convertObjectToType(results.next(), dimensionSpec.getOutputType())
           );
         }
         Iterator<AggregatorFactory> aggsIter = aggs.iterator();
-        while (aggsIter.hasNext() && results.hasNext()) {
-          final AggregatorFactory factory = aggsIter.next();
-          event.put(factory.getName(), factory.deserialize(results.next()));
-        }
+
+        CacheStrategy.fetchAggregatorsFromCache(
+            aggsIter,
+            results,
+            isResultLevelCache,
+            (aggName, aggValueObject) -> {
+              event.put(aggName, aggValueObject);
+              return null;
+            }
+        );
 
         if (isResultLevelCache) {
           Iterator<PostAggregator> postItr = query.getPostAggregatorSpecs().iterator();
           while (postItr.hasNext() && results.hasNext()) {

processing/src/main/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChest.java

@@ -327,17 +327,23 @@ public class TimeseriesQueryQueryToolChest extends QueryToolChest<Result<TimeseriesResultValue>, TimeseriesQuery>
       public Result<TimeseriesResultValue> apply(@Nullable Object input)
       {
         List<Object> results = (List<Object>) input;
-        Map<String, Object> retVal = Maps.newLinkedHashMap();
+        final Map<String, Object> retVal = Maps.newLinkedHashMap();
 
         Iterator<AggregatorFactory> aggsIter = aggs.iterator();
         Iterator<Object> resultIter = results.iterator();
 
         DateTime timestamp = granularity.toDateTime(((Number) resultIter.next()).longValue());
-        while (aggsIter.hasNext() && resultIter.hasNext()) {
-          final AggregatorFactory factory = aggsIter.next();
-          retVal.put(factory.getName(), factory.deserialize(resultIter.next()));
-        }
+
+        CacheStrategy.fetchAggregatorsFromCache(
+            aggsIter,
+            resultIter,
+            isResultLevelCache,
+            (aggName, aggValueObject) -> {
+              retVal.put(aggName, aggValueObject);
+              return null;
+            }
+        );
 
         if (isResultLevelCache) {
           Iterator<PostAggregator> postItr = query.getPostAggregatorSpecs().iterator();
           while (postItr.hasNext() && resultIter.hasNext()) {

processing/src/main/java/org/apache/druid/query/topn/TopNQueryQueryToolChest.java

@@ -398,7 +398,7 @@ public class TopNQueryQueryToolChest extends QueryToolChest<Result<TopNResultValue>, TopNQuery>
       while (inputIter.hasNext()) {
         List<Object> result = (List<Object>) inputIter.next();
 
-        Map<String, Object> vals = Maps.newLinkedHashMap();
+        final Map<String, Object> vals = Maps.newLinkedHashMap();
 
         Iterator<AggregatorFactory> aggIter = aggs.iterator();
         Iterator<Object> resultIter = result.iterator();
@@ -409,10 +409,15 @@ public class TopNQueryQueryToolChest extends QueryToolChest<Result<TopNResultValue>, TopNQuery>
             DimensionHandlerUtils.convertObjectToType(resultIter.next(), query.getDimensionSpec().getOutputType())
         );
 
-        while (aggIter.hasNext() && resultIter.hasNext()) {
-          final AggregatorFactory factory = aggIter.next();
-          vals.put(factory.getName(), factory.deserialize(resultIter.next()));
-        }
+        CacheStrategy.fetchAggregatorsFromCache(
+            aggIter,
+            resultIter,
+            isResultLevelCache,
+            (aggName, aggValueObject) -> {
+              vals.put(aggName, aggValueObject);
+              return null;
+            }
+        );
 
         for (PostAggregator postAgg : postAggs) {
           vals.put(postAgg.getName(), postAgg.compute(vals));

processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryQueryToolChestTest.java

@@ -19,15 +19,26 @@
 package org.apache.druid.query.groupby;
 
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
+import org.apache.druid.collections.SerializablePair;
+import org.apache.druid.data.input.MapBasedRow;
 import org.apache.druid.data.input.Row;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.query.CacheStrategy;
 import org.apache.druid.query.QueryRunnerTestHelper;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
 import org.apache.druid.query.aggregation.FloatSumAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
+import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
 import org.apache.druid.query.aggregation.post.ExpressionPostAggregator;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.expression.TestExprMacroTable;
@@ -46,10 +57,14 @@ import org.apache.druid.query.groupby.having.OrHavingSpec;
 import org.apache.druid.query.groupby.orderby.DefaultLimitSpec;
 import org.apache.druid.query.groupby.orderby.OrderByColumnSpec;
 import org.apache.druid.query.ordering.StringComparators;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.segment.column.ValueType;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.io.IOException;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 
 public class GroupByQueryQueryToolChestTest
@@ -483,4 +498,143 @@ public class GroupByQueryQueryToolChestTest
     ));
   }
 
+  @Test
+  public void testCacheStrategy() throws Exception
+  {
+    doTestCacheStrategy(ValueType.STRING, "val1");
+    doTestCacheStrategy(ValueType.FLOAT, 2.1f);
+    doTestCacheStrategy(ValueType.DOUBLE, 2.1d);
+    doTestCacheStrategy(ValueType.LONG, 2L);
+  }
+
+  private AggregatorFactory getComplexAggregatorFactoryForValueType(final ValueType valueType)
+  {
+    switch (valueType) {
+      case LONG:
+        return new LongLastAggregatorFactory("complexMetric", "test");
+      case DOUBLE:
+        return new DoubleLastAggregatorFactory("complexMetric", "test");
+      case FLOAT:
+        return new FloatLastAggregatorFactory("complexMetric", "test");
+      case STRING:
+        return new StringLastAggregatorFactory("complexMetric", "test", null);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
+  private SerializablePair getIntermediateComplexValue(final ValueType valueType, final Object dimValue)
+  {
+    switch (valueType) {
+      case LONG:
+      case DOUBLE:
+      case FLOAT:
+        return new SerializablePair<>(123L, dimValue);
+      case STRING:
+        return new SerializablePairLongString(123L, (String) dimValue);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
+  private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException
+  {
+    final GroupByQuery query1 = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.dataSource)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.firstToThird)
+        .setDimensions(Collections.singletonList(
+            new DefaultDimensionSpec("test", "test", valueType)
+        ))
+        .setAggregatorSpecs(
+            Arrays.asList(
+                QueryRunnerTestHelper.rowsCount,
+                getComplexAggregatorFactoryForValueType(valueType)
+            )
+        )
+        .setPostAggregatorSpecs(
+            ImmutableList.of(new ConstantPostAggregator("post", 10))
+        )
+        .setGranularity(QueryRunnerTestHelper.dayGran)
+        .build();
+
+    CacheStrategy<Row, Object, GroupByQuery> strategy =
+        new GroupByQueryQueryToolChest(null, null).getCacheStrategy(
+            query1
+        );
+
+    final Row result1 = new MapBasedRow(
+        // test timestamps that result in integer size millis
+        DateTimes.utc(123L),
+        ImmutableMap.of(
+            "test", dimValue,
+            "rows", 1,
+            "complexMetric", getIntermediateComplexValue(valueType, dimValue)
+        )
+    );
+
+    Object preparedValue = strategy.prepareForSegmentLevelCache().apply(
+        result1
+    );
+
+    ObjectMapper objectMapper = TestHelper.makeJsonMapper();
+    Object fromCacheValue = objectMapper.readValue(
+        objectMapper.writeValueAsBytes(preparedValue),
+        strategy.getCacheObjectClazz()
+    );
+
+    Row fromCacheResult = strategy.pullFromSegmentLevelCache().apply(fromCacheValue);
+    Assert.assertEquals(result1, fromCacheResult);
+
+    final Row result2 = new MapBasedRow(
+        // test timestamps that result in integer size millis
+        DateTimes.utc(123L),
+        ImmutableMap.of(
+            "test", dimValue,
+            "rows", 1,
+            "complexMetric", dimValue,
+            "post", 10
+        )
+    );
+
+    // Please see the comments on aggregator serde and type handling in CacheStrategy.fetchAggregatorsFromCache()
+    final Row typeAdjustedResult2;
+    if (valueType == ValueType.FLOAT) {
+      typeAdjustedResult2 = new MapBasedRow(
+          DateTimes.utc(123L),
+          ImmutableMap.of(
+              "test", dimValue,
+              "rows", 1,
+              "complexMetric", 2.1d,
+              "post", 10
+          )
+      );
+    } else if (valueType == ValueType.LONG) {
+      typeAdjustedResult2 = new MapBasedRow(
+          DateTimes.utc(123L),
+          ImmutableMap.of(
+              "test", dimValue,
+              "rows", 1,
+              "complexMetric", 2,
+              "post", 10
+          )
+      );
+    } else {
+      typeAdjustedResult2 = result2;
+    }
+
+    Object preparedResultCacheValue = strategy.prepareForCache(true).apply(
+        result2
+    );
+
+    Object fromResultCacheValue = objectMapper.readValue(
+        objectMapper.writeValueAsBytes(preparedResultCacheValue),
+        strategy.getCacheObjectClazz()
+    );
+
+    Row fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue);
+    Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult);
+  }
 }

processing/src/test/java/org/apache/druid/query/timeseries/TimeseriesQueryQueryToolChestTest.java

@@ -32,6 +32,8 @@ import org.apache.druid.query.Result;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
 import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
 import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
 import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
@@ -77,7 +79,8 @@ public class TimeseriesQueryQueryToolChestTest
         Granularities.ALL,
         ImmutableList.of(
             new CountAggregatorFactory("metric1"),
-            new LongSumAggregatorFactory("metric0", "metric0")
+            new LongSumAggregatorFactory("metric0", "metric0"),
+            new StringLastAggregatorFactory("complexMetric", "test", null)
         ),
         ImmutableList.of(new ConstantPostAggregator("post", 10)),
         0,
@@ -89,7 +92,11 @@
         // test timestamps that result in integer size millis
         DateTimes.utc(123L),
         new TimeseriesResultValue(
-            ImmutableMap.of("metric1", 2, "metric0", 3)
+            ImmutableMap.of(
+                "metric1", 2,
+                "metric0", 3,
+                "complexMetric", new SerializablePairLongString(123L, "val1")
+            )
         )
     );
@@ -109,7 +116,12 @@
         // test timestamps that result in integer size millis
         DateTimes.utc(123L),
         new TimeseriesResultValue(
-            ImmutableMap.of("metric1", 2, "metric0", 3, "post", 10)
+            ImmutableMap.of(
+                "metric1", 2,
+                "metric0", 3,
+                "complexMetric", "val1",
+                "post", 10
+            )
         )
     );

processing/src/test/java/org/apache/druid/query/topn/TopNQueryQueryToolChestTest.java

@@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import org.apache.druid.collections.CloseableStupidPool;
+import org.apache.druid.collections.SerializablePair;
 import org.apache.druid.java.util.common.DateTimes;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.granularity.Granularities;
@@ -35,8 +36,14 @@ import org.apache.druid.query.QueryRunnerTestHelper;
 import org.apache.druid.query.Result;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.TestQueryRunners;
+import org.apache.druid.query.aggregation.AggregatorFactory;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.aggregation.SerializablePairLongString;
+import org.apache.druid.query.aggregation.last.DoubleLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.FloatLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.LongLastAggregatorFactory;
+import org.apache.druid.query.aggregation.last.StringLastAggregatorFactory;
 import org.apache.druid.query.aggregation.post.ArithmeticPostAggregator;
 import org.apache.druid.query.aggregation.post.ConstantPostAggregator;
 import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator;
@@ -269,6 +276,36 @@ public class TopNQueryQueryToolChestTest
     }
   }
 
+  private AggregatorFactory getComplexAggregatorFactoryForValueType(final ValueType valueType)
+  {
+    switch (valueType) {
+      case LONG:
+        return new LongLastAggregatorFactory("complexMetric", "test");
+      case DOUBLE:
+        return new DoubleLastAggregatorFactory("complexMetric", "test");
+      case FLOAT:
+        return new FloatLastAggregatorFactory("complexMetric", "test");
+      case STRING:
+        return new StringLastAggregatorFactory("complexMetric", "test", null);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
+  private SerializablePair getIntermediateComplexValue(final ValueType valueType, final Object dimValue)
+  {
+    switch (valueType) {
+      case LONG:
+      case DOUBLE:
+      case FLOAT:
+        return new SerializablePair<>(123L, dimValue);
+      case STRING:
+        return new SerializablePairLongString(123L, (String) dimValue);
+      default:
+        throw new IllegalArgumentException("bad valueType: " + valueType);
+    }
+  }
+
   private void doTestCacheStrategy(final ValueType valueType, final Object dimValue) throws IOException
   {
     CacheStrategy<Result<TopNResultValue>, Object, TopNQuery> strategy =
@@ -282,7 +319,10 @@
             new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("2015-01-01/2015-01-02"))),
             null,
             Granularities.ALL,
-            ImmutableList.of(new CountAggregatorFactory("metric1")),
+            ImmutableList.of(
+                new CountAggregatorFactory("metric1"),
+                getComplexAggregatorFactoryForValueType(valueType)
+            ),
             ImmutableList.of(new ConstantPostAggregator("post", 10)),
             null
         )
@@ -295,7 +335,8 @@
             Collections.singletonList(
                 ImmutableMap.of(
                     "test", dimValue,
-                    "metric1", 2
+                    "metric1", 2,
+                    "complexMetric", getIntermediateComplexValue(valueType, dimValue)
                 )
             )
         )
@@ -323,12 +364,48 @@
                 ImmutableMap.of(
                     "test", dimValue,
                     "metric1", 2,
+                    "complexMetric", dimValue,
                     "post", 10
                 )
             )
         )
     );
 
+    // Please see the comments on aggregator serde and type handling in CacheStrategy.fetchAggregatorsFromCache()
+    final Result<TopNResultValue> typeAdjustedResult2;
+    if (valueType == ValueType.FLOAT) {
+      typeAdjustedResult2 = new Result<>(
+          DateTimes.utc(123L),
+          new TopNResultValue(
+              Collections.singletonList(
+                  ImmutableMap.of(
+                      "test", dimValue,
+                      "metric1", 2,
+                      "complexMetric", 2.1d,
+                      "post", 10
+                  )
+              )
+          )
+      );
+    } else if (valueType == ValueType.LONG) {
+      typeAdjustedResult2 = new Result<>(
+          DateTimes.utc(123L),
+          new TopNResultValue(
+              Collections.singletonList(
+                  ImmutableMap.of(
+                      "test", dimValue,
+                      "metric1", 2,
+                      "complexMetric", 2,
+                      "post", 10
+                  )
+              )
+          )
+      );
+    } else {
+      typeAdjustedResult2 = result2;
+    }
+
     Object preparedResultCacheValue = strategy.prepareForCache(true).apply(
         result2
     );
@@ -339,7 +416,7 @@
     );
 
     Result<TopNResultValue> fromResultCacheResult = strategy.pullFromCache(true).apply(fromResultCacheValue);
-    Assert.assertEquals(result2, fromResultCacheResult);
+    Assert.assertEquals(typeAdjustedResult2, fromResultCacheResult);
   }
 
   static class MockQueryRunner implements QueryRunner<Result<TopNResultValue>>