fix filtered Aggregator

fix filtered Aggregator
remove unused name parameter for filtered aggregator
add tests
This commit is contained in:
nishantmonu51 2014-11-19 16:40:25 +05:30
parent f2d94eecde
commit 6fd37ce023
5 changed files with 291 additions and 37 deletions

View File

@ -28,14 +28,12 @@ import javax.annotation.Nullable;
public class FilteredAggregator implements Aggregator public class FilteredAggregator implements Aggregator
{ {
private final String name;
private final DimensionSelector dimSelector; private final DimensionSelector dimSelector;
private final Aggregator delegate; private final Aggregator delegate;
private final IntPredicate predicate; private final IntPredicate predicate;
public FilteredAggregator(String name, DimensionSelector dimSelector, IntPredicate predicate, Aggregator delegate) public FilteredAggregator(DimensionSelector dimSelector, IntPredicate predicate, Aggregator delegate)
{ {
this.name = name;
this.dimSelector = dimSelector; this.dimSelector = dimSelector;
this.delegate = delegate; this.delegate = delegate;
this.predicate = predicate; this.predicate = predicate;
@ -75,7 +73,7 @@ public class FilteredAggregator implements Aggregator
@Override @Override
public String getName() public String getName()
{ {
return name; return delegate.getName();
} }
@Override @Override

View File

@ -22,7 +22,6 @@ package io.druid.query.aggregation;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.metamx.common.ISE; import com.metamx.common.ISE;
import com.metamx.common.Pair;
import io.druid.query.filter.DimFilter; import io.druid.query.filter.DimFilter;
import io.druid.query.filter.NotDimFilter; import io.druid.query.filter.NotDimFilter;
import io.druid.query.filter.SelectorDimFilter; import io.druid.query.filter.SelectorDimFilter;
@ -37,12 +36,10 @@ public class FilteredAggregatorFactory implements AggregatorFactory
{ {
private static final byte CACHE_TYPE_ID = 0x9; private static final byte CACHE_TYPE_ID = 0x9;
private final String name;
private final AggregatorFactory delegate; private final AggregatorFactory delegate;
private final DimFilter filter; private final DimFilter filter;
public FilteredAggregatorFactory( public FilteredAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("aggregator") AggregatorFactory delegate, @JsonProperty("aggregator") AggregatorFactory delegate,
@JsonProperty("filter") DimFilter filter @JsonProperty("filter") DimFilter filter
) )
@ -55,7 +52,6 @@ public class FilteredAggregatorFactory implements AggregatorFactory
"FilteredAggregator currently only supports filters of type 'selector' and their negation" "FilteredAggregator currently only supports filters of type 'selector' and their negation"
); );
this.name = name;
this.delegate = delegate; this.delegate = delegate;
this.filter = filter; this.filter = filter;
} }
@ -64,22 +60,46 @@ public class FilteredAggregatorFactory implements AggregatorFactory
public Aggregator factorize(ColumnSelectorFactory metricFactory) public Aggregator factorize(ColumnSelectorFactory metricFactory)
{ {
final Aggregator aggregator = delegate.factorize(metricFactory); final Aggregator aggregator = delegate.factorize(metricFactory);
final Pair<DimensionSelector, IntPredicate> selectorPredicatePair = makeFilterPredicate( SelectorDimFilter selector = getSelector(filter);
filter, final DimensionSelector dimensionSelector = metricFactory.makeDimensionSelector(selector.getDimension());
metricFactory if (dimensionSelector == null) {
// dimension does not exist
if (filter instanceof NotDimFilter) {
// all rows match the not criteria
return aggregator;
} else {
// no rows match the selector filter
return Aggregators.noopAggregator();
}
}
return new FilteredAggregator(
dimensionSelector,
makeFilterPredicate(filter, dimensionSelector, selector.getValue()),
aggregator
); );
return new FilteredAggregator(name, selectorPredicatePair.lhs, selectorPredicatePair.rhs, aggregator);
} }
@Override @Override
public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory) public BufferAggregator factorizeBuffered(ColumnSelectorFactory metricFactory)
{ {
final BufferAggregator aggregator = delegate.factorizeBuffered(metricFactory); final BufferAggregator aggregator = delegate.factorizeBuffered(metricFactory);
final Pair<DimensionSelector, IntPredicate> selectorPredicatePair = makeFilterPredicate( SelectorDimFilter selector = getSelector(filter);
filter, final DimensionSelector dimensionSelector = metricFactory.makeDimensionSelector(selector.getDimension());
metricFactory if (dimensionSelector == null) {
// dimension does not exist
if (filter instanceof NotDimFilter) {
// all rows match the not criteria
return aggregator;
} else {
// no rows match the selector filter
return Aggregators.noopBufferAggregator();
}
}
return new FilteredBufferAggregator(
dimensionSelector,
makeFilterPredicate(filter, dimensionSelector, selector.getValue()),
aggregator
); );
return new FilteredBufferAggregator(selectorPredicatePair.lhs, selectorPredicatePair.rhs, aggregator);
} }
@Override @Override
@ -116,7 +136,7 @@ public class FilteredAggregatorFactory implements AggregatorFactory
@Override @Override
public String getName() public String getName()
{ {
return name; return delegate.getName();
} }
@Override @Override
@ -173,23 +193,13 @@ public class FilteredAggregatorFactory implements AggregatorFactory
return delegate.getRequiredColumns(); return delegate.getRequiredColumns();
} }
private static Pair<DimensionSelector, IntPredicate> makeFilterPredicate( private IntPredicate makeFilterPredicate(
final DimFilter dimFilter, final DimFilter dimFilter,
final ColumnSelectorFactory metricFactory final DimensionSelector dimSelector,
final String filterValue
) )
{ {
final SelectorDimFilter selector; final int lookupId = dimSelector.lookupId(filterValue);
if (dimFilter instanceof NotDimFilter) {
// we only support NotDimFilter with Selector filter
selector = (SelectorDimFilter) ((NotDimFilter) dimFilter).getField();
} else if (dimFilter instanceof SelectorDimFilter) {
selector = (SelectorDimFilter) dimFilter;
} else {
throw new ISE("Unsupported DimFilter type [%d]", dimFilter.getClass());
}
final DimensionSelector dimSelector = metricFactory.makeDimensionSelector(selector.getDimension());
final int lookupId = dimSelector.lookupId(selector.getValue());
final IntPredicate predicate; final IntPredicate predicate;
if (dimFilter instanceof NotDimFilter) { if (dimFilter instanceof NotDimFilter) {
predicate = new IntPredicate() predicate = new IntPredicate()
@ -210,7 +220,59 @@ public class FilteredAggregatorFactory implements AggregatorFactory
} }
}; };
} }
return Pair.of(dimSelector, predicate); return predicate;
} }
/**
 * Extracts the selector filter that a filtered aggregator is built on.
 *
 * A {@link SelectorDimFilter} is returned as is; for a {@link NotDimFilter} the wrapped field is
 * cast to a selector and returned.
 *
 * @param dimFilter the filter to unwrap
 * @return the selector itself, or the selector wrapped by the NOT filter
 * @throws ISE if dimFilter is some other filter type
 */
public static SelectorDimFilter getSelector(DimFilter dimFilter)
{
  if (dimFilter instanceof NotDimFilter) {
    // we only support NotDimFilter with Selector filter
    return (SelectorDimFilter) ((NotDimFilter) dimFilter).getField();
  }
  if (dimFilter instanceof SelectorDimFilter) {
    return (SelectorDimFilter) dimFilter;
  }
  // %s rather than %d: the argument is a Class, which an integer conversion cannot format.
  throw new ISE("Unsupported DimFilter type [%s]", dimFilter.getClass());
}
/**
 * Renders this factory as {@code FilteredAggregatorFactory{delegate=..., filter=...}}.
 *
 * @return a string containing the delegate and the filter
 */
@Override
public String toString()
{
  // The first field follows the opening brace directly; only later fields get a ", " separator.
  return "FilteredAggregatorFactory{" +
         "delegate=" + delegate +
         ", filter=" + filter +
         '}';
}
/**
 * Two factories are equal when they are of exactly the same class and both their delegate and
 * their filter are equal (two null fields count as equal).
 *
 * @param o the object to compare against
 * @return true if o is an equal FilteredAggregatorFactory
 */
@Override
public boolean equals(Object o)
{
  if (o == this) {
    return true;
  }
  if (o == null || o.getClass() != getClass()) {
    return false;
  }

  final FilteredAggregatorFactory other = (FilteredAggregatorFactory) o;

  // && short-circuits, so the filter is only compared once the delegates are known to match.
  return (delegate == null ? other.delegate == null : delegate.equals(other.delegate))
         && (filter == null ? other.filter == null : filter.equals(other.filter));
}
/**
 * Hashes the delegate and the filter, the same two fields that {@code equals} compares.
 * A null field contributes 0.
 *
 * @return 31 * hash(delegate) + hash(filter)
 */
@Override
public int hashCode()
{
  final int delegateHash = (delegate == null) ? 0 : delegate.hashCode();
  final int filterHash = (filter == null) ? 0 : filter.hashCode();
  return 31 * delegateHash + filterHash;
}
} }

View File

@ -837,7 +837,11 @@ public class IncrementalIndex implements Iterable<Row>, Closeable
public int getId(String value) public int getId(String value)
{ {
return falseIds.get(value); if (value == null) {
return -1;
}
final Integer id = falseIds.get(value);
return id == null ? -1 : id;
} }
public String getValue(int id) public String getValue(int id)

View File

@ -46,7 +46,6 @@ public class FilteredAggregatorTest
final TestFloatColumnSelector selector = new TestFloatColumnSelector(values); final TestFloatColumnSelector selector = new TestFloatColumnSelector(values);
FilteredAggregatorFactory factory = new FilteredAggregatorFactory( FilteredAggregatorFactory factory = new FilteredAggregatorFactory(
"test",
new DoubleSumAggregatorFactory("billy", "value"), new DoubleSumAggregatorFactory("billy", "value"),
new SelectorDimFilter("dim", "a") new SelectorDimFilter("dim", "a")
); );
@ -55,7 +54,7 @@ public class FilteredAggregatorTest
makeColumnSelector(selector) makeColumnSelector(selector)
); );
Assert.assertEquals("test", agg.getName()); Assert.assertEquals("billy", agg.getName());
double expectedFirst = new Float(values[0]).doubleValue(); double expectedFirst = new Float(values[0]).doubleValue();
double expectedSecond = new Float(values[1]).doubleValue() + expectedFirst; double expectedSecond = new Float(values[1]).doubleValue() + expectedFirst;
@ -164,7 +163,6 @@ public class FilteredAggregatorTest
final TestFloatColumnSelector selector = new TestFloatColumnSelector(values); final TestFloatColumnSelector selector = new TestFloatColumnSelector(values);
FilteredAggregatorFactory factory = new FilteredAggregatorFactory( FilteredAggregatorFactory factory = new FilteredAggregatorFactory(
"test",
new DoubleSumAggregatorFactory("billy", "value"), new DoubleSumAggregatorFactory("billy", "value"),
new NotDimFilter(new SelectorDimFilter("dim", "b")) new NotDimFilter(new SelectorDimFilter("dim", "b"))
); );
@ -173,7 +171,7 @@ public class FilteredAggregatorTest
makeColumnSelector(selector) makeColumnSelector(selector)
); );
Assert.assertEquals("test", agg.getName()); Assert.assertEquals("billy", agg.getName());
double expectedFirst = new Float(values[0]).doubleValue(); double expectedFirst = new Float(values[0]).doubleValue();
double expectedSecond = new Float(values[1]).doubleValue() + expectedFirst; double expectedSecond = new Float(values[1]).doubleValue() + expectedFirst;

View File

@ -32,12 +32,15 @@ import io.druid.query.QueryRunner;
import io.druid.query.QueryRunnerTestHelper; import io.druid.query.QueryRunnerTestHelper;
import io.druid.query.Result; import io.druid.query.Result;
import io.druid.query.aggregation.AggregatorFactory; import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.CountAggregatorFactory;
import io.druid.query.aggregation.FilteredAggregatorFactory;
import io.druid.query.aggregation.LongSumAggregatorFactory; import io.druid.query.aggregation.LongSumAggregatorFactory;
import io.druid.query.aggregation.MaxAggregatorFactory; import io.druid.query.aggregation.MaxAggregatorFactory;
import io.druid.query.aggregation.MinAggregatorFactory; import io.druid.query.aggregation.MinAggregatorFactory;
import io.druid.query.aggregation.PostAggregator; import io.druid.query.aggregation.PostAggregator;
import io.druid.query.filter.AndDimFilter; import io.druid.query.filter.AndDimFilter;
import io.druid.query.filter.DimFilter; import io.druid.query.filter.DimFilter;
import io.druid.query.filter.NotDimFilter;
import io.druid.query.filter.RegexDimFilter; import io.druid.query.filter.RegexDimFilter;
import io.druid.query.spec.MultipleIntervalSegmentSpec; import io.druid.query.spec.MultipleIntervalSegmentSpec;
import io.druid.segment.TestHelper; import io.druid.segment.TestHelper;
@ -1658,4 +1661,193 @@ public class TimeseriesQueryRunnerTest
); );
TestHelper.assertExpectedResults(expectedResults, actualResults); TestHelper.assertExpectedResults(expectedResults, actualResults);
} }
/**
 * Runs an all-granularity timeseries query with the common aggregators plus a count that is
 * filtered to rows whose provider dimension equals "spot", and checks the single result row.
 */
@Test
public void testTimeSeriesWithFilteredAgg()
{
  // Count restricted by a plain selector filter on the provider dimension.
  final FilteredAggregatorFactory spotCount = new FilteredAggregatorFactory(
      new CountAggregatorFactory("filteredAgg"),
      Druids.newSelectorDimFilterBuilder()
            .dimension(QueryRunnerTestHelper.providerDimension)
            .value("spot")
            .build()
  );

  final TimeseriesQuery filteredQuery =
      Druids.newTimeseriesQueryBuilder()
            .dataSource(QueryRunnerTestHelper.dataSource)
            .granularity(QueryRunnerTestHelper.allGran)
            .intervals(QueryRunnerTestHelper.firstToThird)
            .aggregators(
                Lists.newArrayList(
                    Iterables.concat(QueryRunnerTestHelper.commonAggregators, Lists.newArrayList(spotCount))
                )
            )
            .postAggregators(Arrays.<PostAggregator>asList(QueryRunnerTestHelper.addRowsIndexConstant))
            .build();

  final Iterable<Result<TimeseriesResultValue>> actual = Sequences.toList(
      runner.run(filteredQuery),
      Lists.<Result<TimeseriesResultValue>>newArrayList()
  );

  // The filtered count (18) is expected to be lower than the unfiltered row count (26).
  final ImmutableMap<String, Object> expectedRow = ImmutableMap.<String, Object>of(
      "filteredAgg", 18L,
      "addRowsIndexConstant", 12486.361190795898d,
      "index", 12459.361190795898d,
      "uniques", 9.019833517963864d,
      "rows", 26L
  );
  final List<Result<TimeseriesResultValue>> expected = Arrays.asList(
      new Result<TimeseriesResultValue>(new DateTime("2011-04-01"), new TimeseriesResultValue(expectedRow))
  );

  TestHelper.assertExpectedResults(expected, actual);
}
/**
 * Runs an all-granularity timeseries query whose filtered count uses a negated selector on the
 * dimension "abraKaDabra" (per the test name, a dimension that is not present) and checks that
 * the filtered count equals the total row count.
 */
@Test
public void testTimeSeriesWithFilteredAggDimensionNotPresent()
{
  // NOT(selector) on the "abraKaDabra" dimension.
  final NotDimFilter notOnMissingDimension = new NotDimFilter(
      Druids.newSelectorDimFilterBuilder()
            .dimension("abraKaDabra")
            .value("Lol")
            .build()
  );
  final FilteredAggregatorFactory filteredCount = new FilteredAggregatorFactory(
      new CountAggregatorFactory("filteredAgg"),
      notOnMissingDimension
  );

  final TimeseriesQuery filteredQuery =
      Druids.newTimeseriesQueryBuilder()
            .dataSource(QueryRunnerTestHelper.dataSource)
            .granularity(QueryRunnerTestHelper.allGran)
            .intervals(QueryRunnerTestHelper.firstToThird)
            .aggregators(
                Lists.newArrayList(
                    Iterables.concat(QueryRunnerTestHelper.commonAggregators, Lists.newArrayList(filteredCount))
                )
            )
            .postAggregators(Arrays.<PostAggregator>asList(QueryRunnerTestHelper.addRowsIndexConstant))
            .build();

  final Iterable<Result<TimeseriesResultValue>> actual = Sequences.toList(
      runner.run(filteredQuery),
      Lists.<Result<TimeseriesResultValue>>newArrayList()
  );

  // Every row is expected to pass the negated filter, so filteredAgg matches rows (26).
  final ImmutableMap<String, Object> expectedRow = ImmutableMap.<String, Object>of(
      "filteredAgg", 26L,
      "addRowsIndexConstant", 12486.361190795898d,
      "index", 12459.361190795898d,
      "uniques", 9.019833517963864d,
      "rows", 26L
  );
  final List<Result<TimeseriesResultValue>> expected = Arrays.asList(
      new Result<TimeseriesResultValue>(new DateTime("2011-04-01"), new TimeseriesResultValue(expectedRow))
  );

  TestHelper.assertExpectedResults(expected, actual);
}
/**
 * Runs an all-granularity timeseries query whose filtered count uses a negated selector for the
 * value "LolLol" on the provider dimension (per the test name, a value that is not present) and
 * checks that the filtered count equals the total row count.
 */
@Test
public void testTimeSeriesWithFilteredAggValueNotPresent()
{
  // NOT(provider == "LolLol").
  final NotDimFilter notOnMissingValue = new NotDimFilter(
      Druids.newSelectorDimFilterBuilder()
            .dimension(QueryRunnerTestHelper.providerDimension)
            .value("LolLol")
            .build()
  );
  final FilteredAggregatorFactory filteredCount = new FilteredAggregatorFactory(
      new CountAggregatorFactory("filteredAgg"),
      notOnMissingValue
  );

  final TimeseriesQuery filteredQuery =
      Druids.newTimeseriesQueryBuilder()
            .dataSource(QueryRunnerTestHelper.dataSource)
            .granularity(QueryRunnerTestHelper.allGran)
            .intervals(QueryRunnerTestHelper.firstToThird)
            .aggregators(
                Lists.newArrayList(
                    Iterables.concat(QueryRunnerTestHelper.commonAggregators, Lists.newArrayList(filteredCount))
                )
            )
            .postAggregators(Arrays.<PostAggregator>asList(QueryRunnerTestHelper.addRowsIndexConstant))
            .build();

  final Iterable<Result<TimeseriesResultValue>> actual = Sequences.toList(
      runner.run(filteredQuery),
      Lists.<Result<TimeseriesResultValue>>newArrayList()
  );

  // Every row is expected to pass the negated filter, so filteredAgg matches rows (26).
  final ImmutableMap<String, Object> expectedRow = ImmutableMap.<String, Object>of(
      "filteredAgg", 26L,
      "addRowsIndexConstant", 12486.361190795898d,
      "index", 12459.361190795898d,
      "uniques", 9.019833517963864d,
      "rows", 26L
  );
  final List<Result<TimeseriesResultValue>> expected = Arrays.asList(
      new Result<TimeseriesResultValue>(new DateTime("2011-04-01"), new TimeseriesResultValue(expectedRow))
  );

  TestHelper.assertExpectedResults(expected, actual);
}
/**
 * Runs an all-granularity timeseries query whose filtered count uses a negated selector with a
 * null value on the provider dimension and checks that the filtered count equals the total row
 * count.
 */
@Test
public void testTimeSeriesWithFilteredAggNullValue()
{
  // NOT(provider == null).
  final NotDimFilter notNullProvider = new NotDimFilter(
      Druids.newSelectorDimFilterBuilder()
            .dimension(QueryRunnerTestHelper.providerDimension)
            .value(null)
            .build()
  );
  final FilteredAggregatorFactory filteredCount = new FilteredAggregatorFactory(
      new CountAggregatorFactory("filteredAgg"),
      notNullProvider
  );

  final TimeseriesQuery filteredQuery =
      Druids.newTimeseriesQueryBuilder()
            .dataSource(QueryRunnerTestHelper.dataSource)
            .granularity(QueryRunnerTestHelper.allGran)
            .intervals(QueryRunnerTestHelper.firstToThird)
            .aggregators(
                Lists.newArrayList(
                    Iterables.concat(QueryRunnerTestHelper.commonAggregators, Lists.newArrayList(filteredCount))
                )
            )
            .postAggregators(Arrays.<PostAggregator>asList(QueryRunnerTestHelper.addRowsIndexConstant))
            .build();

  final Iterable<Result<TimeseriesResultValue>> actual = Sequences.toList(
      runner.run(filteredQuery),
      Lists.<Result<TimeseriesResultValue>>newArrayList()
  );

  // Every row is expected to pass the negated filter, so filteredAgg matches rows (26).
  final ImmutableMap<String, Object> expectedRow = ImmutableMap.<String, Object>of(
      "filteredAgg", 26L,
      "addRowsIndexConstant", 12486.361190795898d,
      "index", 12459.361190795898d,
      "uniques", 9.019833517963864d,
      "rows", 26L
  );
  final List<Result<TimeseriesResultValue>> expected = Arrays.asList(
      new Result<TimeseriesResultValue>(new DateTime("2011-04-01"), new TimeseriesResultValue(expectedRow))
  );

  TestHelper.assertExpectedResults(expected, actual);
}
} }