Merge pull request #2425 from will-lauer/sketchErrorBounds

Adding optional error bounds to sketch aggs and post-aggs
This commit is contained in:
Himanshu 2016-02-11 11:41:17 -06:00
commit 7a0bfa693b
9 changed files with 271 additions and 27 deletions

View File

@ -35,15 +35,18 @@ public class SketchEstimatePostAggregator implements PostAggregator
private final String name;
private final PostAggregator field;
private final Integer errorBoundsStdDev;
@JsonCreator
public SketchEstimatePostAggregator(
@JsonProperty("name") String name,
@JsonProperty("field") PostAggregator field
@JsonProperty("field") PostAggregator field,
@JsonProperty("errorBoundsStdDev") Integer errorBoundsStdDev
)
{
this.name = Preconditions.checkNotNull(name, "name is null");
this.field = Preconditions.checkNotNull(field, "field is null");
this.errorBoundsStdDev = errorBoundsStdDev;
}
@Override
@ -64,8 +67,17 @@ public class SketchEstimatePostAggregator implements PostAggregator
public Object compute(Map<String, Object> combinedAggregators)
{
Sketch sketch = (Sketch) field.compute(combinedAggregators);
if (errorBoundsStdDev != null) {
SketchEstimateWithErrorBounds result = new SketchEstimateWithErrorBounds(
sketch.getEstimate(),
sketch.getUpperBound(errorBoundsStdDev),
sketch.getLowerBound(errorBoundsStdDev),
errorBoundsStdDev);
return result;
} else {
return sketch.getEstimate();
}
}
@Override
@JsonProperty
@ -80,12 +92,19 @@ public class SketchEstimatePostAggregator implements PostAggregator
return field;
}
@JsonProperty
public Integer getErrorBoundsStdDev()
{
return errorBoundsStdDev;
}
@Override
public String toString()
{
return "SketchEstimatePostAggregator{" +
"name='" + name + '\'' +
", field=" + field +
", errorBoundsStdDev=" + errorBoundsStdDev +
"}";
}
@ -104,6 +123,9 @@ public class SketchEstimatePostAggregator implements PostAggregator
if (!name.equals(that.name)) {
return false;
}
if (errorBoundsStdDev != that.errorBoundsStdDev) {
return false;
}
return field.equals(that.field);
}
@ -113,6 +135,7 @@ public class SketchEstimatePostAggregator implements PostAggregator
{
int result = name.hashCode();
result = 31 * result + field.hashCode();
result = 31 * result + (errorBoundsStdDev != null ? errorBoundsStdDev.hashCode() : 0);
return result;
}
}

View File

@ -0,0 +1,113 @@
/**
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.datasketches.theta;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
/**
* Container class used to return estimates in conjunction with
* estimated error bounds.
*/
public class SketchEstimateWithErrorBounds
{
final private double estimate;
final private double highBound;
final private double lowBound;
final private int numStdDev;
@JsonCreator
public SketchEstimateWithErrorBounds(
@JsonProperty("estimate") double estimate,
@JsonProperty("highBound") double highBound,
@JsonProperty("lowBound") double lowBound,
@JsonProperty("numStdDev") int numStdDev
)
{
this.estimate = estimate;
this.highBound = highBound;
this.lowBound = lowBound;
this.numStdDev = numStdDev;
}
@JsonProperty
public double getEstimate()
{
return estimate;
}
@JsonProperty
public double getHighBound()
{
return highBound;
}
@JsonProperty
public double getLowBound()
{
return lowBound;
}
@JsonProperty
public int getNumStdDev()
{
return numStdDev;
}
@Override
public String toString()
{
return "SketchEstimateWithErrorBounds{" +
"estimate=" + Double.toString(estimate) +
", highBound=" + Double.toString(highBound) +
", lowBound="+ Double.toString(lowBound) +
", numStdDev=" + Integer.toString(numStdDev) +
"}";
}
@Override
public int hashCode()
{
int result = Double.valueOf(estimate).hashCode();
result = 31 * result + Double.valueOf(highBound).hashCode();
result = 31 * result + Double.valueOf(lowBound).hashCode();
result = 31 * result + Integer.valueOf(numStdDev).hashCode();
return result;
}
@Override
public boolean equals(Object obj)
{
if (this == obj) {
return true;
} else if (obj == null || getClass() != obj.getClass()) {
return false;
}
SketchEstimateWithErrorBounds that = (SketchEstimateWithErrorBounds) obj;
if (estimate != that.estimate ||
highBound != that.highBound ||
lowBound != that.lowBound ||
numStdDev != numStdDev) {
return false;
}
return true;
}
}

View File

@ -35,6 +35,7 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
private final boolean shouldFinalize;
private final boolean isInputThetaSketch;
private final Integer errorBoundsStdDev;
@JsonCreator
public SketchMergeAggregatorFactory(
@ -42,12 +43,14 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
@JsonProperty("fieldName") String fieldName,
@JsonProperty("size") Integer size,
@JsonProperty("shouldFinalize") Boolean shouldFinalize,
@JsonProperty("isInputThetaSketch") Boolean isInputThetaSketch
@JsonProperty("isInputThetaSketch") Boolean isInputThetaSketch,
@JsonProperty("errorBoundsStdDev") Integer errorBoundsStdDev
)
{
super(name, fieldName, size, CACHE_TYPE_ID);
this.shouldFinalize = (shouldFinalize == null) ? true : shouldFinalize.booleanValue();
this.isInputThetaSketch = (isInputThetaSketch == null) ? false : isInputThetaSketch.booleanValue();
this.errorBoundsStdDev = errorBoundsStdDev;
}
@Override
@ -59,7 +62,8 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
fieldName,
size,
shouldFinalize,
isInputThetaSketch
isInputThetaSketch,
errorBoundsStdDev
)
);
}
@ -67,7 +71,7 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
@Override
public AggregatorFactory getCombiningFactory()
{
return new SketchMergeAggregatorFactory(name, name, size, shouldFinalize, false);
return new SketchMergeAggregatorFactory(name, name, size, shouldFinalize, false, errorBoundsStdDev);
}
@Override
@ -81,7 +85,8 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
name,
Math.max(size, castedOther.size),
shouldFinalize,
true
true,
errorBoundsStdDev
);
} else {
throw new AggregatorFactoryNotMergeableException(this, other);
@ -100,6 +105,12 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
return isInputThetaSketch;
}
@JsonProperty
public Integer getErrorBoundsStdDev()
{
return errorBoundsStdDev;
}
/**
* Finalize the computation on sketch object and returns estimate from underlying
* sketch.
@ -112,7 +123,17 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
public Object finalizeComputation(Object object)
{
if (shouldFinalize) {
return ((Sketch) object).getEstimate();
Sketch sketch = (Sketch) object;
if (errorBoundsStdDev != null) {
SketchEstimateWithErrorBounds result = new SketchEstimateWithErrorBounds(
sketch.getEstimate(),
sketch.getUpperBound(errorBoundsStdDev),
sketch.getLowerBound(errorBoundsStdDev),
errorBoundsStdDev);
return result;
} else {
return sketch.getEstimate();
}
} else {
return object;
}
@ -146,6 +167,9 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
if (shouldFinalize != that.shouldFinalize) {
return false;
}
if (errorBoundsStdDev != that.errorBoundsStdDev) {
return false;
}
return isInputThetaSketch == that.isInputThetaSketch;
}
@ -156,6 +180,7 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
int result = super.hashCode();
result = 31 * result + (shouldFinalize ? 1 : 0);
result = 31 * result + (isInputThetaSketch ? 1 : 0);
result = 31 * result + (errorBoundsStdDev != null ? errorBoundsStdDev.hashCode() : 0);
return result;
}
@ -166,8 +191,9 @@ public class SketchMergeAggregatorFactory extends SketchAggregatorFactory
+ "fieldName=" + fieldName
+ ", name=" + name
+ ", size=" + size
+ ",shouldFinalize=" + shouldFinalize
+ ", shouldFinalize=" + shouldFinalize
+ ", isInputThetaSketch=" + isInputThetaSketch
+ ", errorBoundsStdDev=" + errorBoundsStdDev
+ "}";
}
}

View File

@ -34,6 +34,6 @@ public class OldSketchBuildAggregatorFactory extends SketchMergeAggregatorFactor
@JsonProperty("size") Integer size
)
{
super(name, fieldName, size, true, false);
super(name, fieldName, size, true, false, null);
}
}

View File

@ -34,6 +34,6 @@ public class OldSketchEstimatePostAggregator extends SketchEstimatePostAggregato
@JsonProperty("field") PostAggregator field
)
{
super(name, field);
super(name, field, null);
}
}

View File

@ -35,6 +35,6 @@ public class OldSketchMergeAggregatorFactory extends SketchMergeAggregatorFactor
@JsonProperty("shouldFinalize") Boolean shouldFinalize
)
{
super(name, fieldName, size, shouldFinalize, true);
super(name, fieldName, size, shouldFinalize, true, null);
}
}

View File

@ -28,6 +28,7 @@ import com.metamx.common.guava.Sequences;
import com.yahoo.sketches.theta.Sketch;
import com.yahoo.sketches.theta.Sketches;
import io.druid.data.input.MapBasedRow;
import io.druid.data.input.Row;
import io.druid.granularity.QueryGranularity;
import io.druid.query.Result;
import io.druid.query.aggregation.AggregationTestHelper;
@ -65,7 +66,7 @@ public class SketchAggregationTest
@Test
public void testSimpleDataIngestAndGpByQuery() throws Exception
{
Sequence seq = helper.createIndexAndRunQueryOnSegment(
Sequence<Row> seq = helper.createIndexAndRunQueryOnSegment(
new File(this.getClass().getClassLoader().getResource("simple_test_data.tsv").getFile()),
readFileFromClasspathAsString("simple_test_data_record_parser.json"),
readFileFromClasspathAsString("simple_test_data_aggregators.json"),
@ -75,7 +76,7 @@ public class SketchAggregationTest
readFileFromClasspathAsString("simple_test_data_group_by_query.json")
);
List results = Sequences.toList(seq, Lists.newArrayList());
List<Row> results = Sequences.toList(seq, Lists.<Row>newArrayList());
Assert.assertEquals(1, results.size());
Assert.assertEquals(
new MapBasedRow(
@ -123,7 +124,7 @@ public class SketchAggregationTest
@Test
public void testSketchDataIngestAndGpByQuery() throws Exception
{
Sequence seq = helper.createIndexAndRunQueryOnSegment(
Sequence<Row> seq = helper.createIndexAndRunQueryOnSegment(
new File(SketchAggregationTest.class.getClassLoader().getResource("sketch_test_data.tsv").getFile()),
readFileFromClasspathAsString("sketch_test_data_record_parser.json"),
readFileFromClasspathAsString("sketch_test_data_aggregators.json"),
@ -133,7 +134,7 @@ public class SketchAggregationTest
readFileFromClasspathAsString("sketch_test_data_group_by_query.json")
);
List results = Sequences.toList(seq, Lists.newArrayList());
List<Row> results = Sequences.toList(seq, Lists.<Row>newArrayList());
Assert.assertEquals(1, results.size());
Assert.assertEquals(
new MapBasedRow(
@ -141,7 +142,11 @@ public class SketchAggregationTest
ImmutableMap
.<String, Object>builder()
.put("sids_sketch_count", 50.0)
.put("sids_sketch_count_with_err",
new SketchEstimateWithErrorBounds(50.0, 50.0, 50.0, 2))
.put("sketchEstimatePostAgg", 50.0)
.put("sketchEstimatePostAggWithErrorBounds",
new SketchEstimateWithErrorBounds(50.0, 50.0, 50.0, 2))
.put("sketchUnionPostAggEstimate", 50.0)
.put("sketchIntersectionPostAggEstimate", 50.0)
.put("sketchAnotBPostAggEstimate", 0.0)
@ -155,7 +160,7 @@ public class SketchAggregationTest
@Test
public void testThetaCardinalityOnSimpleColumn() throws Exception
{
Sequence seq = helper.createIndexAndRunQueryOnSegment(
Sequence<Row> seq = helper.createIndexAndRunQueryOnSegment(
new File(SketchAggregationTest.class.getClassLoader().getResource("simple_test_data.tsv").getFile()),
readFileFromClasspathAsString("simple_test_data_record_parser2.json"),
"["
@ -170,7 +175,7 @@ public class SketchAggregationTest
readFileFromClasspathAsString("simple_test_data_group_by_query.json")
);
List results = Sequences.toList(seq, Lists.newArrayList());
List<Row> results = Sequences.toList(seq, Lists.<Row>newArrayList());
Assert.assertEquals(1, results.size());
Assert.assertEquals(
new MapBasedRow(
@ -192,9 +197,10 @@ public class SketchAggregationTest
@Test
public void testSketchMergeAggregatorFactorySerde() throws Exception
{
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, null, null));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, false, true));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, true, false));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, null, null, null));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, false, true, null));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, true, false, null));
assertAggregatorFactorySerde(new SketchMergeAggregatorFactory("name", "fieldName", 16, true, false, 2));
}
@Test
@ -202,14 +208,22 @@ public class SketchAggregationTest
{
Sketch sketch = Sketches.updateSketchBuilder().build(128);
SketchMergeAggregatorFactory agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, null, null);
SketchMergeAggregatorFactory agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, null, null, null);
Assert.assertEquals(0.0, ((Double) agg.finalizeComputation(sketch)).doubleValue(), 0.0001);
agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, true, null);
agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, true, null, null);
Assert.assertEquals(0.0, ((Double) agg.finalizeComputation(sketch)).doubleValue(), 0.0001);
agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, false, null);
agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, false, null, null);
Assert.assertEquals(sketch, agg.finalizeComputation(sketch));
agg = new SketchMergeAggregatorFactory("name", "fieldName", 16, true, null, 2);
SketchEstimateWithErrorBounds est = (SketchEstimateWithErrorBounds) agg.finalizeComputation(sketch);
Assert.assertEquals(0.0, est.getEstimate(), 0.0001);
Assert.assertEquals(0.0, est.getHighBound(), 0.0001);
Assert.assertEquals(0.0, est.getLowBound(), 0.0001);
Assert.assertEquals(2, est.getNumStdDev());
}
private void assertAggregatorFactorySerde(AggregatorFactory agg) throws Exception
@ -229,7 +243,16 @@ public class SketchAggregationTest
assertPostAggregatorSerde(
new SketchEstimatePostAggregator(
"name",
new FieldAccessPostAggregator("name", "fieldName")
new FieldAccessPostAggregator("name", "fieldName"),
null
)
);
assertPostAggregatorSerde(
new SketchEstimatePostAggregator(
"name",
new FieldAccessPostAggregator("name", "fieldName"),
2
)
);
}

View File

@ -0,0 +1,43 @@
/**
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.datasketches.theta;
import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.druid.jackson.DefaultObjectMapper;
public class SketchEstimateWithErrorBoundsTest
{
@Test
public void testSerde() throws IOException
{
ObjectMapper mapper = new DefaultObjectMapper();
SketchEstimateWithErrorBounds est = new SketchEstimateWithErrorBounds(100.0,101.5,98.5,2);
Assert.assertEquals(est, mapper.readValue(
mapper.writeValueAsString(est), SketchEstimateWithErrorBounds.class));
}
}

View File

@ -10,6 +10,13 @@
"fieldName": "sids_sketch",
"size": 16384
},
{
"type": "thetaSketch",
"name": "sids_sketch_count_with_err",
"fieldName": "sids_sketch",
"size": 16384,
"errorBoundsStdDev": 2
},
{
"type": "thetaSketch",
"name": "non_existing_col_validation",
@ -26,6 +33,15 @@
"fieldName": "sids_sketch_count"
}
},
{
"type": "thetaSketchEstimate",
"name": "sketchEstimatePostAggWithErrorBounds",
"errorBoundsStdDev": 2,
"field": {
"type": "fieldAccess",
"fieldName": "sids_sketch_count"
}
},
{
"type": "thetaSketchEstimate",
"name": "sketchIntersectionPostAggEstimate",