SOLR-7306: percentiles for new facet module

git-svn-id: https://svn.apache.org/repos/asf/lucene/dev/trunk@1669189 13f79535-47bb-0310-9956-ffa450edef68
This commit is contained in:
Yonik Seeley 2015-03-25 18:02:26 +00:00
parent 29c1de0fa8
commit 3a8d0c2f38
5 changed files with 341 additions and 4 deletions

View File

@ -200,6 +200,14 @@ New Features
* SOLR-6350: StatsComponent now supports Percentiles (Xu Zhang, hossman)
* SOLR-7306: Percentiles support for the new facet module. Percentiles
can be calculated for all facet buckets and field faceting can sort
by percentile values.
Examples:
json.facet={ median_age : "percentile(age,50)" }
json.facet={ salary_percentiles : "percentile(salary,25,50,75)" }
(yonik)
Bug Fixes
----------------------

View File

@ -45,6 +45,7 @@ import org.apache.solr.search.facet.AvgAgg;
import org.apache.solr.search.facet.CountAgg;
import org.apache.solr.search.facet.MaxAgg;
import org.apache.solr.search.facet.MinAgg;
import org.apache.solr.search.facet.PercentileAgg;
import org.apache.solr.search.facet.SumAgg;
import org.apache.solr.search.facet.SumsqAgg;
import org.apache.solr.search.facet.UniqueAgg;
@ -868,7 +869,7 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin {
}
});
addParser("agg_percentile", new PercentileAgg.Parser());
}

View File

@ -0,0 +1,218 @@
package org.apache.solr.search.facet;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.tdunning.math.stats.AVLTreeDigest;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.solr.search.FunctionQParser;
import org.apache.solr.search.SyntaxError;
import org.apache.solr.search.ValueSourceParser;
public class PercentileAgg extends SimpleAggValueSource {
List<Double> percentiles;
public PercentileAgg(ValueSource vs, List<Double> percentiles) {
super("percentile", vs);
this.percentiles = percentiles;
}
@Override
public SlotAcc createSlotAcc(FacetContext fcontext, int numDocs, int numSlots) throws IOException {
return new Acc(getArg(), fcontext, numSlots);
}
@Override
public FacetMerger createFacetMerger(Object prototype) {
return new Merger();
}
@Override
public boolean equals(Object o) {
if (!(o instanceof PercentileAgg)) return false;
PercentileAgg other = (PercentileAgg)o;
return this.arg.equals(other.arg) && this.percentiles.equals(other.percentiles);
}
@Override
public int hashCode() {
return super.hashCode() * 31 + percentiles.hashCode();
}
public static class Parser extends ValueSourceParser {
@Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError {
List<Double> percentiles = new ArrayList<>();
ValueSource vs = fp.parseValueSource();
while (fp.hasMoreArguments()) {
double val = fp.parseDouble();
if (val<0 || val>100) {
throw new SyntaxError("requested percentile must be between 0 and 100. got " + val);
}
percentiles.add(val);
}
if (percentiles.isEmpty()) {
throw new SyntaxError("expected percentile(valsource,percent1[,percent2]*) EXAMPLE:percentile(myfield,50)");
}
return new PercentileAgg(vs, percentiles);
}
}
protected Object getValueFromDigest(AVLTreeDigest digest) {
if (digest == null) {
return null;
}
if (percentiles.size() == 1) {
return digest.quantile( percentiles.get(0) * 0.01 );
}
List<Double> lst = new ArrayList(percentiles.size());
for (Double percentile : percentiles) {
double val = digest.quantile( percentile * 0.01 );
lst.add( val );
}
return lst;
}
class Acc extends FuncSlotAcc {
protected AVLTreeDigest[] digests;
protected ByteBuffer buf;
protected double[] sortvals;
public Acc(ValueSource values, FacetContext fcontext, int numSlots) {
super(values, fcontext, numSlots);
digests = new AVLTreeDigest[numSlots];
}
public void collect(int doc, int slotNum) {
if (!values.exists(doc)) return;
double val = values.doubleVal(doc);
AVLTreeDigest digest = digests[slotNum];
if (digest == null) {
digests[slotNum] = digest = new AVLTreeDigest(100); // TODO: make compression configurable
}
digest.add(val);
}
@Override
public int compare(int slotA, int slotB) {
if (sortvals == null) {
fillSortVals();
}
return Double.compare(sortvals[slotA], sortvals[slotB]);
}
private void fillSortVals() {
sortvals = new double[ digests.length ];
double sortp = percentiles.get(0) * 0.01;
for (int i=0; i<digests.length; i++) {
AVLTreeDigest digest = digests[i];
if (digest == null) {
sortvals[i] = Double.NEGATIVE_INFINITY;
} else {
sortvals[i] = digest.quantile(sortp);
}
}
}
@Override
public Object getValue(int slotNum) throws IOException {
if (fcontext.isShard()) {
return getShardValue(slotNum);
}
if (sortvals != null && percentiles.size()==1) {
// we've already calculated everything we need
return sortvals[slotNum];
}
return getValueFromDigest( digests[slotNum] );
}
public Object getShardValue(int slot) throws IOException {
AVLTreeDigest digest = digests[slot];
if (digest == null) return null; // no values for this slot
digest.compress();
int sz = digest.byteSize();
if (buf == null || buf.capacity() < sz) {
buf = ByteBuffer.allocate(sz+(sz>>1)); // oversize by 50%
} else {
buf.clear();
}
digest.asSmallBytes(buf);
byte[] arr = Arrays.copyOf(buf.array(), buf.position());
return arr;
}
@Override
public void reset() {
digests = new AVLTreeDigest[digests.length];
sortvals = null;
}
}
class Merger extends FacetSortableMerger {
protected AVLTreeDigest digest;
protected Double sortVal;
@Override
public void merge(Object facetResult) {
byte[] arr = (byte[])facetResult;
AVLTreeDigest subDigest = AVLTreeDigest.fromBytes(ByteBuffer.wrap(arr));
if (digest == null) {
digest = subDigest;
} else {
digest.add(subDigest);
}
}
@Override
public Object getMergedResult() {
if (percentiles.size() == 1) return getSortVal();
return getValueFromDigest(digest);
}
@Override
public int compareTo(FacetSortableMerger other, FacetField.SortDirection direction) {
return Double.compare(getSortVal(), ((Merger) other).getSortVal());
}
private Double getSortVal() {
if (sortVal == null) {
sortVal = digest.quantile( percentiles.get(0) * 0.01 );
}
return sortVal;
}
}
}

View File

@ -956,6 +956,7 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
assertFuncEquals("agg_count()", "agg_count()");
assertFuncEquals("agg_unique(foo_i)", "agg_unique(foo_i)");
assertFuncEquals("agg_sumsq(foo_i)", "agg_sumsq(foo_i)");
assertFuncEquals("agg_percentile(foo_i,50)", "agg_percentile(foo_i,50)");
// assertFuncEquals("agg_stdev(foo_i)", "agg_stdev(foo_i)");
// assertFuncEquals("agg_multistat(foo_i)", "agg_multistat(foo_i)");
}

View File

@ -17,6 +17,7 @@ package org.apache.solr.search.facet;
* limitations under the License.
*/
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -24,6 +25,7 @@ import java.util.Comparator;
import java.util.List;
import java.util.Random;
import com.tdunning.math.stats.AVLTreeDigest;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.JSONTestUtil;
import org.apache.solr.SolrTestCaseHS;
@ -366,6 +368,30 @@ public class TestJsonFacets extends SolrTestCaseHS {
", f2:{ 'buckets':[{ val:'B', count:3, n1:-3.0}, { val:'A', count:2, n1:6.0 }]} }"
);
// percentiles 0,10,50,90,100
// catA: 2.0 2.2 3.0 3.8 4.0
// catB: -9.0 -8.2 -5.0 7.800000000000001 11.0
// all: -9.0 -7.3999999999999995 2.0 8.200000000000001 11.0
// test sorting by single percentile
client.testJQ(params(p, "q", "*:*"
, "json.facet", "{f1:{terms:{field:'${cat_s}', sort:'n1 desc', facet:{n1:'percentile(${num_d},50)'} }}" +
" , f2:{terms:{field:'${cat_s}', sort:'n1 asc', facet:{n1:'percentile(${num_d},50)'} }} }"
)
, "facets=={ 'count':6, " +
" f1:{ 'buckets':[{ val:'A', count:2, n1:3.0 }, { val:'B', count:3, n1:-5.0}]}" +
", f2:{ 'buckets':[{ val:'B', count:3, n1:-5.0}, { val:'A', count:2, n1:3.0 }]} }"
);
// test sorting by multiple percentiles (sort is by first)
client.testJQ(params(p, "q", "*:*"
, "json.facet", "{f1:{terms:{field:'${cat_s}', sort:'n1 desc', facet:{n1:'percentile(${num_d},50,0,100)'} }}" +
" , f2:{terms:{field:'${cat_s}', sort:'n1 asc', facet:{n1:'percentile(${num_d},50,0,100)'} }} }"
)
, "facets=={ 'count':6, " +
" f1:{ 'buckets':[{ val:'A', count:2, n1:[3.0,2.0,4.0] }, { val:'B', count:3, n1:[-5.0,-9.0,11.0] }]}" +
", f2:{ 'buckets':[{ val:'B', count:3, n1:[-5.0,-9.0,11.0]}, { val:'A', count:2, n1:[3.0,2.0,4.0] }]} }"
);
// test sorting by count/index order
client.testJQ(params(p, "q", "*:*"
, "json.facet", "{f1:{terms:{field:'${cat_s}', sort:'count desc' } }" +
@ -557,15 +583,15 @@ public class TestJsonFacets extends SolrTestCaseHS {
// stats at top level
client.testJQ(params(p, "q", "*:*"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})', numwhere:'unique(${where_s})' }"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})', numwhere:'unique(${where_s})', med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)' }"
)
, "facets=={ 'count':6, " +
"sum1:3.0, sumsq1:247.0, avg1:0.5, min1:-9.0, max1:11.0, numwhere:2 }"
"sum1:3.0, sumsq1:247.0, avg1:0.5, min1:-9.0, max1:11.0, numwhere:2, med:2.0, perc:[-9.0,2.0,11.0] }"
);
// stats at top level, no matches
client.testJQ(params(p, "q", "id:DOESNOTEXIST"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})', numwhere:'unique(${where_s})' }"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})', numwhere:'unique(${where_s})', med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)' }"
)
, "facets=={count:0 " +
"/* ,sum1:0.0, sumsq1:0.0, avg1:0.0, min1:'NaN', max1:'NaN', numwhere:0 */ }"
@ -671,4 +697,87 @@ public class TestJsonFacets extends SolrTestCaseHS {
doStats( client, params() );
}
/***
public void testPercentiles() {
AVLTreeDigest catA = new AVLTreeDigest(100);
catA.add(4);
catA.add(2);
AVLTreeDigest catB = new AVLTreeDigest(100);
catB.add(-9);
catB.add(11);
catB.add(-5);
AVLTreeDigest all = new AVLTreeDigest(100);
all.add(catA);
all.add(catB);
System.out.println(str(catA));
System.out.println(str(catB));
System.out.println(str(all));
// 2.0 2.2 3.0 3.8 4.0
// -9.0 -8.2 -5.0 7.800000000000001 11.0
// -9.0 -7.3999999999999995 2.0 8.200000000000001 11.0
}
private static String str(AVLTreeDigest digest) {
StringBuilder sb = new StringBuilder();
for (double d : new double[] {0,.1,.5,.9,1}) {
sb.append(" ").append(digest.quantile(d));
}
return sb.toString();
}
***/
/*** test code to ensure TDigest is working as we expect.
@Test
public void testTDigest() throws Exception {
AVLTreeDigest t1 = new AVLTreeDigest(100);
t1.add(10, 1);
t1.add(90, 1);
t1.add(50, 1);
System.out.println(t1.quantile(0.1));
System.out.println(t1.quantile(0.5));
System.out.println(t1.quantile(0.9));
assertEquals(t1.quantile(0.5), 50.0, 0.01);
AVLTreeDigest t2 = new AVLTreeDigest(100);
t2.add(130, 1);
t2.add(170, 1);
t2.add(90, 1);
System.out.println(t2.quantile(0.1));
System.out.println(t2.quantile(0.5));
System.out.println(t2.quantile(0.9));
AVLTreeDigest top = new AVLTreeDigest(100);
t1.compress();
ByteBuffer buf = ByteBuffer.allocate(t1.byteSize()); // upper bound
t1.asSmallBytes(buf);
byte[] arr1 = Arrays.copyOf(buf.array(), buf.position());
ByteBuffer rbuf = ByteBuffer.wrap(arr1);
top.add(AVLTreeDigest.fromBytes(rbuf));
System.out.println(top.quantile(0.1));
System.out.println(top.quantile(0.5));
System.out.println(top.quantile(0.9));
t2.compress();
ByteBuffer buf2 = ByteBuffer.allocate(t2.byteSize()); // upper bound
t2.asSmallBytes(buf2);
byte[] arr2 = Arrays.copyOf(buf2.array(), buf2.position());
ByteBuffer rbuf2 = ByteBuffer.wrap(arr2);
top.add(AVLTreeDigest.fromBytes(rbuf2));
System.out.println(top.quantile(0.1));
System.out.println(top.quantile(0.5));
System.out.println(top.quantile(0.9));
}
******/
}