adding new post aggregators for test statistics to druid-stats extension (#4532)

* adding new post aggregators of test stats to druid-stats extension

* changes to address code review comments

* fix checkstyle violations using druid_intellij_formatting.xml after merge upstream/master

* add @Override annotation per CI log

* make changes per review comments/discussions

* remove some blocks per review comments
This commit is contained in:
chunghochen 2017-10-09 23:43:27 -07:00 committed by Jonathan Wei
parent 7cc18226cd
commit 0614b92df1
8 changed files with 665 additions and 1 deletions

View File

@ -0,0 +1,98 @@
---
layout: doc_page
---
# Test Stats Aggregators
Incorporates test statistics related aggregators, including z-score and p-value. Please refer to [https://www.paypal-engineering.com/2017/06/29/democratizing-experimentation-data-for-product-innovations/](https://www.paypal-engineering.com/2017/06/29/democratizing-experimentation-data-for-product-innovations/) for math background and details.
Make sure to include `druid-stats` extension in order to use these aggregrators.
## Z-Score for two sample ztests post aggregator
Please refer to [https://www.isixsigma.com/tools-templates/hypothesis-testing/making-sense-two-proportions-test/](https://www.isixsigma.com/tools-templates/hypothesis-testing/making-sense-two-proportions-test/) and [http://www.ucs.louisiana.edu/~jcb0773/Berry_statbook/Berry_statbook_chpt6.pdf](http://www.ucs.louisiana.edu/~jcb0773/Berry_statbook/Berry_statbook_chpt6.pdf) for more details.
z = (p1 - p2) / S.E. (assuming null hypothesis is true)
Please see below for p1 and p2.
Please note S.E. stands for standard error where
S.E. = sqrt{ p1 * ( 1 - p1 )/n1 + p2 * (1 - p2)/n2) }
(p1 p2) is the observed difference between two sample proportions.
### zscore2sample post aggregator
* **`zscore2sample`**: calculate the z-score using two-sample z-test while converting binary variables (***e.g.*** success or not) to continuous variables (***e.g.*** conversion rate).
```json
{
"type": "zscore2sample",
"name": "<output_name>",
"successCount1": <post_aggregator> success count of sample 1,
"sample1Size": <post_aggregaror> sample 1 size,
"successCount2": <post_aggregator> success count of sample 2,
"sample2Size" : <post_aggregator> sample 2 size
}
```
Please note the post aggregator will be converting binary variables to continuous variables for two population proportions. Specifically
p1 = (successCount1) / (sample size 1)
p2 = (successCount2) / (sample size 2)
### pvalue2tailedZtest post aggregator
* **`pvalue2tailedZtest`**: calculate p-value of two-sided z-test from zscore
- ***pvalue2tailedZtest(zscore)*** - the input is a z-score which can be calculated using the zscore2sample post aggregator
```json
{
"type": "pvalue2tailedZtest",
"name": "<output_name>",
"zScore": <zscore post_aggregator>
}
```
## Example Usage
In this example, we use zscore2sample post aggregator to calculate z-score, and then feed the z-score to pvalue2tailedZtest post aggregator to calculate p-value.
A JSON query example can be as follows:
```json
{
...
"postAggregations" : {
"type" : "pvalue2tailedZtest",
"name" : "pvalue",
"zScore" :
{
"type" : "zscore2sample",
"name" : "zscore",
"successCount1" :
{ "type" : "constant",
"name" : "successCountFromPopulation1Sample",
"value" : 300
},
"sample1Size" :
{ "type" : "constant",
"name" : "sampleSizeOfPopulation1",
"value" : 500
},
"successCount2":
{ "type" : "constant",
"name" : "successCountFromPopulation2Sample",
"value" : 450
},
"sample2Size" :
{ "type" : "constant",
"name" : "sampleSizeOfPopulation2",
"value" : 600
}
}
}
}
```

View File

@ -40,6 +40,10 @@
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
<!-- Tests -->
<dependency>

View File

@ -24,6 +24,8 @@ import com.fasterxml.jackson.databind.module.SimpleModule;
import com.google.common.collect.ImmutableList;
import com.google.inject.Binder;
import io.druid.initialization.DruidModule;
import io.druid.query.aggregation.teststats.PvaluefromZscorePostAggregator;
import io.druid.query.aggregation.teststats.ZtestPostAggregator;
import io.druid.query.aggregation.variance.StandardDeviationPostAggregator;
import io.druid.query.aggregation.variance.VarianceAggregatorFactory;
import io.druid.query.aggregation.variance.VarianceFoldingAggregatorFactory;
@ -43,7 +45,9 @@ public class DruidStatsModule implements DruidModule
new SimpleModule().registerSubtypes(
VarianceAggregatorFactory.class,
VarianceFoldingAggregatorFactory.class,
StandardDeviationPostAggregator.class
StandardDeviationPostAggregator.class,
ZtestPostAggregator.class,
PvaluefromZscorePostAggregator.class
)
);
}

View File

@ -0,0 +1,167 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.teststats;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import io.druid.query.Queries;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.PostAggregatorIds;
import io.druid.query.cache.CacheKeyBuilder;
import org.apache.commons.math3.distribution.NormalDistribution;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
@JsonTypeName("pvalue2tailedZtest")
public class PvaluefromZscorePostAggregator implements PostAggregator
{
private final String name;
private final PostAggregator zScore;
@JsonCreator
public PvaluefromZscorePostAggregator(
@JsonProperty("name") String name,
@JsonProperty("zScore") PostAggregator zScore
)
{
Preconditions.checkNotNull(
name,
"Must have a valid, non-null post-aggregator"
);
Preconditions.checkNotNull(
zScore,
"Must have a valid, non-null post-aggregator"
);
this.name = name;
this.zScore = zScore;
}
@Override
public Set<String> getDependentFields()
{
Set<String> dependentFields = Sets.newHashSet();
dependentFields.addAll(zScore.getDependentFields());
return dependentFields;
}
@Override
public Comparator getComparator()
{
return ArithmeticPostAggregator.DEFAULT_COMPARATOR;
}
@Override
public Object compute(Map<String, Object> combinedAggregators)
{
double zScoreValue = ((Number) zScore.compute(combinedAggregators))
.doubleValue();
zScoreValue = Math.abs(zScoreValue);
return 2 * (1 - cumulativeProbability(zScoreValue));
}
private double cumulativeProbability(double x)
{
try {
NormalDistribution normDist = new NormalDistribution();
return normDist.cumulativeProbability(x);
}
catch (IllegalArgumentException ex) {
return Double.NaN;
}
}
@Override
@JsonProperty
public String getName()
{
return name;
}
@Override
public PostAggregator decorate(Map<String, AggregatorFactory> aggregators)
{
return new PvaluefromZscorePostAggregator(
name,
Iterables.getOnlyElement(Queries.decoratePostAggregators(
Collections.singletonList(zScore), aggregators))
);
}
@JsonProperty
public PostAggregator getZscore()
{
return zScore;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PvaluefromZscorePostAggregator that = (PvaluefromZscorePostAggregator) o;
if (!name.equals(that.name)) {
return false;
}
return (zScore.equals(that.zScore));
}
@Override
public int hashCode()
{
int result = name.hashCode();
result = 31 * result + zScore.hashCode();
return result;
}
@Override
public String toString()
{
return "PvaluefromZscorePostAggregator{" +
"name='" + name + '\'' +
", zScore=" + zScore + '}';
}
@Override
public byte[] getCacheKey()
{
return new CacheKeyBuilder(PostAggregatorIds.PVALUE_FROM_ZTEST)
.appendCacheable(zScore).build();
}
}

View File

@ -0,0 +1,243 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.teststats;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import io.druid.query.Queries;
import io.druid.query.aggregation.AggregatorFactory;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.ArithmeticPostAggregator;
import io.druid.query.aggregation.post.PostAggregatorIds;
import io.druid.query.cache.CacheKeyBuilder;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
/**
* 1. calculating zscore using two-sample Z-Test. IOW,
* using z-test statistic for testing the difference of
* two population proportions.
* 2. converting binary variables (e.g. success or not) to continuous variables (e.g. conversion rate).
* <p>
* Please refer to http://math.mercyhurst.edu/~griff/courses/m109/Lectures/old/Sum_06/sect8.1.pdf
* and http://facweb.cs.depaul.edu/sjost/csc423/documents/test-descriptions/indep-z.pdf
* for more details.
*/
@JsonTypeName("zscore2sample")
public class ZtestPostAggregator implements PostAggregator
{
private final String name;
private final PostAggregator successCount1;
private final PostAggregator sample1Size;
private final PostAggregator successCount2;
private final PostAggregator sample2Size;
@JsonCreator
public ZtestPostAggregator(
@JsonProperty("name") String name,
@JsonProperty("successCount1") PostAggregator successCount1,
@JsonProperty("sample1Size") PostAggregator sample1Size,
@JsonProperty("successCount2") PostAggregator successCount2,
@JsonProperty("sample2Size") PostAggregator sample2Size
)
{
Preconditions.checkNotNull(name, "Must have a valid, non-null post-aggregator name");
Preconditions.checkNotNull(successCount1, "success count from sample 1 can not be null");
Preconditions.checkNotNull(sample1Size, "sample size of population 1 can not null");
Preconditions.checkNotNull(successCount2, "success count from sample 2 can not be null");
Preconditions.checkNotNull(sample2Size, "sample size of population 2 can not be null");
this.name = name;
this.successCount1 = successCount1;
this.sample1Size = sample1Size;
this.successCount2 = successCount2;
this.sample2Size = sample2Size;
}
@Override
public Set<String> getDependentFields()
{
Set<String> dependentFields = Sets.newLinkedHashSet();
dependentFields.addAll(successCount1.getDependentFields());
dependentFields.addAll(sample1Size.getDependentFields());
dependentFields.addAll(successCount2.getDependentFields());
dependentFields.addAll(sample2Size.getDependentFields());
return dependentFields;
}
@Override
public Comparator getComparator()
{
return ArithmeticPostAggregator.DEFAULT_COMPARATOR;
}
@Override
public Object compute(Map<String, Object> combinedAggregators)
{
return zScoreTwoSamples(
((Number) successCount1.compute(combinedAggregators)).doubleValue(),
((Number) sample1Size.compute(combinedAggregators)).doubleValue(),
((Number) successCount2.compute(combinedAggregators)).doubleValue(),
((Number) sample2Size.compute(combinedAggregators)).doubleValue()
);
}
@Override
@JsonProperty
public String getName()
{
return name;
}
@Override
public ZtestPostAggregator decorate(Map<String, AggregatorFactory> aggregators)
{
return new ZtestPostAggregator(
name,
Iterables
.getOnlyElement(Queries.decoratePostAggregators(Collections.singletonList(successCount1), aggregators)),
Iterables.getOnlyElement(Queries.decoratePostAggregators(Collections.singletonList(sample1Size), aggregators)),
Iterables
.getOnlyElement(Queries.decoratePostAggregators(Collections.singletonList(successCount2), aggregators)),
Iterables.getOnlyElement(Queries.decoratePostAggregators(Collections.singletonList(sample2Size), aggregators))
);
}
/**
* 1. calculating zscore for two-sample Z test. IOW, using z-test statistic
* for testing the difference of two population proportions. 2. converting
* binary variables (e.g. success or not) to continuous variables (e.g.
* conversion rate).
*
* @param s1count - success count of population 1
* @param p1count - sample size of population 1
* @param s2count - the success count of population 2
* @param p2count - sample size of population 2
*/
private double zScoreTwoSamples(double s1count, double p1count, double s2count, double p2count)
{
double convertRate1;
double convertRate2;
Preconditions.checkState(s1count >= 0, "success count can't be negative.");
Preconditions.checkState(s2count >= 0, "success count can't be negative.");
Preconditions.checkState(p1count >= s1count, "sample size can't be smaller than the success count.");
Preconditions.checkState(p2count >= s2count, "sample size can't be smaller than the success count.");
try {
convertRate1 = s1count / p1count;
convertRate2 = s2count / p2count;
return (convertRate1 - convertRate2) /
Math.sqrt((convertRate1 * (1 - convertRate1) / p1count) +
(convertRate2 * (1 - convertRate2) / p2count));
}
catch (IllegalArgumentException ex) {
return 0;
}
}
@JsonProperty
public PostAggregator getSuccessCount1()
{
return successCount1;
}
@JsonProperty
public PostAggregator getSample1Size()
{
return sample1Size;
}
@JsonProperty
public PostAggregator getSuccessCount2()
{
return successCount2;
}
@JsonProperty
public PostAggregator getSample2Size()
{
return sample2Size;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ZtestPostAggregator that = (ZtestPostAggregator) o;
if (!name.equals(that.name)) {
return false;
}
return (successCount1.equals(that.successCount1) &&
sample1Size.equals(that.sample1Size) &&
successCount2.equals(that.successCount2) && sample2Size.equals(that.sample2Size));
}
@Override
public int hashCode()
{
int result = name.hashCode();
result = 31 * result + successCount1.hashCode();
result = 31 * result + sample1Size.hashCode();
result = 31 * result + successCount2.hashCode();
result = 31 * result + sample2Size.hashCode();
return result;
}
@Override
public String toString()
{
return "ZtestPostAggregator{" +
"name='" + name + '\'' +
", successCount1='" + successCount1 + '\'' +
", sample1Size='" + sample1Size + '\'' +
", successCount2='" + successCount2 + '\'' +
", sample2size='" + sample2Size +
'}';
}
@Override
public byte[] getCacheKey()
{
return new CacheKeyBuilder(
PostAggregatorIds.ZTEST)
.appendCacheable(successCount1)
.appendCacheable(sample1Size)
.appendCacheable(successCount2)
.appendCacheable(sample2Size)
.build();
}
}

View File

@ -0,0 +1,59 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.teststats;
import com.google.common.collect.ImmutableMap;
import io.druid.jackson.DefaultObjectMapper;
import io.druid.query.aggregation.post.ConstantPostAggregator;
import org.junit.Assert;
import org.junit.Test;
public class PvaluefromZscorePostAggregatorTest
{
PvaluefromZscorePostAggregator pvaluefromZscorePostAggregator;
ConstantPostAggregator zscore;
@Test
public void testPvaluefromZscorePostAggregator() throws Exception
{
zscore = new ConstantPostAggregator("zscore", -1783.8762354220219);
pvaluefromZscorePostAggregator = new PvaluefromZscorePostAggregator("pvalue", zscore);
double pvalue = ((Number) pvaluefromZscorePostAggregator.compute(ImmutableMap.of(
"zscore",
-1783.8762354220219
))).doubleValue();
/* Assert P-value is positive and very small */
Assert.assertTrue(pvalue >= 0 && pvalue < 0.00001);
}
@Test
public void testSerde() throws Exception
{
DefaultObjectMapper mapper = new DefaultObjectMapper();
PvaluefromZscorePostAggregator postAggregator1 = mapper.readValue(mapper.writeValueAsString(
pvaluefromZscorePostAggregator), PvaluefromZscorePostAggregator.class);
Assert.assertEquals(pvaluefromZscorePostAggregator, postAggregator1);
}
}

View File

@ -0,0 +1,87 @@
/*
* Licensed to Metamarkets Group Inc. (Metamarkets) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. Metamarkets licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package io.druid.query.aggregation.teststats;
import com.google.common.collect.Lists;
import io.druid.jackson.DefaultObjectMapper;
import io.druid.query.aggregation.PostAggregator;
import io.druid.query.aggregation.post.ConstantPostAggregator;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ZtestPostAggregatorTest
{
ZtestPostAggregator ztestPostAggregator;
@Test
public void testZtestPostAggregator() throws Exception
{
ConstantPostAggregator successCount1, sample1Size, successCount2, sample2Size;
successCount1 = new ConstantPostAggregator("successCountPopulation1", 39244);
sample1Size = new ConstantPostAggregator("sampleSizePopulation1", 394298);
successCount2 = new ConstantPostAggregator("successCountPopulation2", 8991275);
sample2Size = new ConstantPostAggregator("sampleSizePopulation2", 9385573);
List<PostAggregator> postAggregatorList;
postAggregatorList = Lists.newArrayList(
successCount1,
sample1Size,
successCount2,
sample2Size
);
Map<String, Object> metricValues = new HashMap<>();
for (PostAggregator pa : postAggregatorList) {
metricValues.put(pa.getName(), ((ConstantPostAggregator) pa).getConstantValue());
}
ztestPostAggregator = new ZtestPostAggregator(
"zscore",
successCount1,
sample1Size,
successCount2,
sample2Size
);
double zscore = ((Number) ztestPostAggregator.compute(metricValues)).doubleValue();
Assert.assertEquals(-1783.8762354220219,
zscore, 0.0001
);
}
@Test
public void testSerde() throws Exception
{
DefaultObjectMapper mapper = new DefaultObjectMapper();
ZtestPostAggregator postAggregator1 =
mapper.readValue(
mapper.writeValueAsString(ztestPostAggregator),
ZtestPostAggregator.class
);
Assert.assertEquals(ztestPostAggregator, postAggregator1);
}
}

View File

@ -42,4 +42,6 @@ public class PostAggregatorIds
public static final byte DATA_SKETCHES_SKETCH_SET = 18;
public static final byte VARIANCE_STANDARD_DEVIATION = 19;
public static final byte FINALIZING_FIELD_ACCESS = 20;
public static final byte ZTEST = 21;
public static final byte PVALUE_FROM_ZTEST = 22;
}