SOLR-10082: JSON Facet API, add stddev and variance functions

This commit is contained in:
yonik 2017-04-17 22:30:29 -04:00
parent d286864d80
commit 3145f781b3
7 changed files with 311 additions and 50 deletions

View File

@ -167,6 +167,10 @@ New Features
initializing CloudSolrClient would not work if you have collection aliases on older versions of Solr initializing CloudSolrClient would not work if you have collection aliases on older versions of Solr
server that doesn't support LISTALIASES. (Ishan Chattopadhyaya, Noble Paul) server that doesn't support LISTALIASES. (Ishan Chattopadhyaya, Noble Paul)
* SOLR-10082: Variance and Standard Deviation aggregators for the JSON Facet API.
Example: json.facet={x:"stddev(field1)", y:"variance(field2)"}
(Rustam Hashimov, yonik)
Optimizations Optimizations
---------------------- ----------------------

View File

@ -58,9 +58,11 @@ import org.apache.solr.search.facet.HLLAgg;
import org.apache.solr.search.facet.MaxAgg; import org.apache.solr.search.facet.MaxAgg;
import org.apache.solr.search.facet.MinAgg; import org.apache.solr.search.facet.MinAgg;
import org.apache.solr.search.facet.PercentileAgg; import org.apache.solr.search.facet.PercentileAgg;
import org.apache.solr.search.facet.StddevAgg;
import org.apache.solr.search.facet.SumAgg; import org.apache.solr.search.facet.SumAgg;
import org.apache.solr.search.facet.SumsqAgg; import org.apache.solr.search.facet.SumsqAgg;
import org.apache.solr.search.facet.UniqueAgg; import org.apache.solr.search.facet.UniqueAgg;
import org.apache.solr.search.facet.VarianceAgg;
import org.apache.solr.search.function.CollapseScoreFunction; import org.apache.solr.search.function.CollapseScoreFunction;
import org.apache.solr.search.function.OrdFieldSource; import org.apache.solr.search.function.OrdFieldSource;
import org.apache.solr.search.function.ReverseOrdFieldSource; import org.apache.solr.search.function.ReverseOrdFieldSource;
@ -931,14 +933,21 @@ public abstract class ValueSourceParser implements NamedListInitializedPlugin {
} }
}); });
/*** addParser("agg_variance", new ValueSourceParser() {
addParser("agg_stdev", new ValueSourceParser() {
@Override @Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError { public ValueSource parse(FunctionQParser fp) throws SyntaxError {
return null; return new VarianceAgg(fp.parseValueSource());
} }
}); });
// Registers the "agg_stddev" function name: parsing it consumes one value-source
// argument from the function parser and wraps it in a StddevAgg.
addParser("agg_stddev", new ValueSourceParser() {
@Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError {
return new StddevAgg(fp.parseValueSource());
}
});
/***
addParser("agg_multistat", new ValueSourceParser() { addParser("agg_multistat", new ValueSourceParser() {
@Override @Override
public ValueSource parse(FunctionQParser fp) throws SyntaxError { public ValueSource parse(FunctionQParser fp) throws SyntaxError {

View File

@ -46,8 +46,7 @@ public abstract class SlotAcc implements Closeable {
this.fcontext = fcontext; this.fcontext = fcontext;
} }
public void setNextReader(LeafReaderContext readerContext) throws IOException { public void setNextReader(LeafReaderContext readerContext) throws IOException {}
}
public abstract void collect(int doc, int slot) throws IOException; public abstract void collect(int doc, int slot) throws IOException;
@ -83,7 +82,6 @@ public abstract class SlotAcc implements Closeable {
return count; return count;
} }
public abstract int compare(int slotA, int slotB); public abstract int compare(int slotA, int slotB);
public abstract Object getValue(int slotNum) throws IOException; public abstract Object getValue(int slotNum) throws IOException;
@ -101,8 +99,7 @@ public abstract class SlotAcc implements Closeable {
public abstract void resize(Resizer resizer); public abstract void resize(Resizer resizer);
@Override @Override
public void close() throws IOException { public void close() throws IOException {}
}
public static abstract class Resizer { public static abstract class Resizer {
public abstract int getNewSize(); public abstract int getNewSize();
@ -181,7 +178,6 @@ abstract class FuncSlotAcc extends SlotAcc {
} }
} }
// have a version that counts the number of times a Slot has been hit? (for avg... what else?) // have a version that counts the number of times a Slot has been hit? (for avg... what else?)
// TODO: make more sense to have func as the base class rather than double? // TODO: make more sense to have func as the base class rather than double?
@ -210,7 +206,6 @@ abstract class DoubleFuncSlotAcc extends FuncSlotAcc {
return Double.compare(result[slotA], result[slotB]); return Double.compare(result[slotA], result[slotB]);
} }
@Override @Override
public Object getValue(int slot) { public Object getValue(int slot) {
return result[slot]; return result[slot];
@ -261,8 +256,6 @@ abstract class IntSlotAcc extends SlotAcc {
} }
} }
class SumSlotAcc extends DoubleFuncSlotAcc { class SumSlotAcc extends DoubleFuncSlotAcc {
public SumSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) { public SumSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) {
super(values, fcontext, numSlots); super(values, fcontext, numSlots);
@ -287,8 +280,6 @@ class SumsqSlotAcc extends DoubleFuncSlotAcc {
} }
} }
class MinSlotAcc extends DoubleFuncSlotAcc { class MinSlotAcc extends DoubleFuncSlotAcc {
public MinSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) { public MinSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) {
super(values, fcontext, numSlots, Double.NaN); super(values, fcontext, numSlots, Double.NaN);
@ -324,7 +315,6 @@ class MaxSlotAcc extends DoubleFuncSlotAcc {
} }
class AvgSlotAcc extends DoubleFuncSlotAcc { class AvgSlotAcc extends DoubleFuncSlotAcc {
int[] counts; int[] counts;
@ -351,7 +341,8 @@ class AvgSlotAcc extends DoubleFuncSlotAcc {
} }
private double avg(double tot, int count) { private double avg(double tot, int count) {
return count==0 ? 0 : tot/count; // returns 0 instead of NaN.. todo - make configurable? if NaN, we need to handle comparisons though... return count == 0 ? 0 : tot / count; // returns 0 instead of NaN.. todo - make configurable? if NaN, we need to
// handle comparisons though...
} }
private double avg(int slot) { private double avg(int slot) {
@ -382,26 +373,151 @@ class AvgSlotAcc extends DoubleFuncSlotAcc {
} }
} }
/**
 * Slot accumulator for the {@code variance} aggregation.
 *
 * <p>For every slot it tracks the number of collected values ({@link #counts}), their running
 * sum ({@link #sum}) and their running sum of squares (kept in the inherited {@code result}
 * array). The reported value is the population variance {@code E[x^2] - E[x]^2}, or {@code 0}
 * for a slot that has not collected any value.
 */
class VarianceSlotAcc extends DoubleFuncSlotAcc {
  /** Number of values collected into each slot. */
  int[] counts;
  /** Running sum of the values collected into each slot. */
  double[] sum;

  public VarianceSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) {
    super(values, fcontext, numSlots);
    this.counts = new int[numSlots];
    this.sum = new double[numSlots];
  }

  /** Clears the inherited accumulator state and zeroes the per-slot counts and sums. */
  @Override
  public void reset() {
    super.reset();
    Arrays.fill(this.counts, 0);
    Arrays.fill(this.sum, 0);
  }

  /** Resizes the inherited state, then the per-slot counts and sums, through {@code resizer}. */
  @Override
  public void resize(Resizer resizer) {
    super.resize(resizer);
    counts = resizer.resize(counts, 0);
    sum = resizer.resize(sum, 0);
  }

  /**
   * Computes the variance of one slot from its accumulated statistics.
   *
   * @param slot slot index
   * @return mean of squares minus squared mean, or {@code 0} when the slot is empty
   */
  private double varianceOf(int slot) {
    final int n = counts[slot];
    if (n == 0) {
      return 0;
    }
    final double meanOfSquares = result[slot] / n;
    final double mean = sum[slot] / n;
    return meanOfSquares - Math.pow(mean, 2);
  }

  /** Orders two slots by their computed variance. */
  @Override
  public int compare(int slotA, int slotB) {
    final double left = varianceOf(slotA);
    final double right = varianceOf(slotB);
    return Double.compare(left, right);
  }

  /**
   * Returns the slot's value.
   *
   * @return when {@code fcontext.isShard()} is true, a three element list
   *     {@code [count, sumOfSquares, sum]}; otherwise the computed variance
   */
  @Override
  public Object getValue(int slot) {
    if (!fcontext.isShard()) {
      return varianceOf(slot);
    }
    ArrayList<Number> partials = new ArrayList<>(3);
    partials.add(counts[slot]);
    partials.add(result[slot]);
    partials.add(sum[slot]);
    return partials;
  }

  /** Folds the document's value into the slot; documents without a value are ignored. */
  @Override
  public void collect(int doc, int slot) throws IOException {
    final double v = values.doubleVal(doc);
    if (!values.exists(doc)) {
      return;
    }
    counts[slot]++;
    result[slot] += v * v;
    sum[slot] += v;
  }
}
/**
 * Slot accumulator for the {@code stddev} aggregation.
 *
 * <p>For every slot it tracks the number of collected values ({@link #counts}), their running
 * sum ({@link #sum}) and their running sum of squares (kept in the inherited {@code result}
 * array). The reported value is the population standard deviation
 * {@code sqrt(E[x^2] - E[x]^2)}, or {@code 0} for a slot that has not collected any value.
 */
class StddevSlotAcc extends DoubleFuncSlotAcc {
  /** Number of values collected into each slot. */
  int[] counts;
  /** Running sum of the values collected into each slot. */
  double[] sum;

  public StddevSlotAcc(ValueSource values, FacetContext fcontext, int numSlots) {
    super(values, fcontext, numSlots);
    counts = new int[numSlots];
    sum = new double[numSlots];
  }

  /** Clears the inherited accumulator state and zeroes the per-slot counts and sums. */
  @Override
  public void reset() {
    super.reset();
    Arrays.fill(counts, 0);
    Arrays.fill(sum, 0);
  }

  /**
   * Resizes every per-slot array through {@code resizer} so that all three statistics stay
   * aligned on the same slot indexes.
   */
  @Override
  public void resize(Resizer resizer) {
    // Presumably super.resize takes care of the inherited result array (the sibling
    // VarianceSlotAcc relies on the same call) -- confirm against DoubleFuncSlotAcc.
    super.resize(resizer);
    this.counts = resizer.resize(this.counts, 0);
    // Fix: this previously resized `result` again and left `sum` at its old length, so a
    // later collect()/getValue() on a slot moved or added by the resize indexed a stale array.
    this.sum = resizer.resize(this.sum, 0);
  }

  /**
   * Computes the standard deviation from accumulated statistics.
   *
   * @param sumSq sum of the squared values
   * @param sum sum of the values
   * @param count number of values
   * @return {@code sqrt(sumSq/count - (sum/count)^2)}, or {@code 0} when {@code count} is 0
   */
  private double stdDev(double sumSq, double sum, int count) {
    // NOTE(review): rounding can make the difference slightly negative, in which case
    // Math.sqrt returns NaN. Left as is so shard-local and merged results keep using the
    // same formula; clamping would have to be mirrored in the StddevAgg merger.
    return count == 0 ? 0 : Math.sqrt((sumSq / count) - Math.pow(sum / count, 2));
  }

  /** Computes the standard deviation of one slot; the value is recomputed on every call. */
  private double stdDev(int slot) {
    return stdDev(result[slot], sum[slot], counts[slot]);
  }

  /** Orders two slots by their computed standard deviation. */
  @Override
  public int compare(int slotA, int slotB) {
    return Double.compare(this.stdDev(slotA), this.stdDev(slotB));
  }

  /**
   * Returns the slot's value.
   *
   * @return when {@code fcontext.isShard()} is true, a three element list
   *     {@code [count, sumOfSquares, sum]}; otherwise the computed standard deviation
   */
  @Override
  public Object getValue(int slot) {
    if (fcontext.isShard()) {
      ArrayList<Number> lst = new ArrayList<>(3);
      lst.add(counts[slot]);
      lst.add(result[slot]);
      lst.add(sum[slot]);
      return lst;
    } else {
      return this.stdDev(slot);
    }
  }

  /** Folds the document's value into the slot; documents without a value are ignored. */
  @Override
  public void collect(int doc, int slot) throws IOException {
    double val = values.doubleVal(doc);
    if (values.exists(doc)) {
      counts[slot]++;
      result[slot] += val * val;
      sum[slot] += val;
    }
  }
}
abstract class CountSlotAcc extends SlotAcc { abstract class CountSlotAcc extends SlotAcc {
public CountSlotAcc(FacetContext fcontext) { public CountSlotAcc(FacetContext fcontext) {
super(fcontext); super(fcontext);
} }
public abstract void incrementCount(int slot, int count); public abstract void incrementCount(int slot, int count);
public abstract int getCount(int slot); public abstract int getCount(int slot);
} }
class CountSlotArrAcc extends CountSlotAcc { class CountSlotArrAcc extends CountSlotAcc {
int[] result; int[] result;
public CountSlotArrAcc(FacetContext fcontext, int numSlots) { public CountSlotArrAcc(FacetContext fcontext, int numSlots) {
super(fcontext); super(fcontext);
result = new int[numSlots]; result = new int[numSlots];
} }
@Override @Override
public void collect(int doc, int slotNum) { // TODO: count arrays can use fewer bytes based on the number of docs in the base set (that's the upper bound for single valued) - look at ttf? public void collect(int doc, int slotNum) { // TODO: count arrays can use fewer bytes based on the number of docs in
// the base set (that's the upper bound for single valued) - look at ttf?
result[slotNum]++; result[slotNum]++;
} }
@ -439,7 +555,6 @@ class CountSlotArrAcc extends CountSlotAcc {
} }
} }
class SortSlotAcc extends SlotAcc { class SortSlotAcc extends SlotAcc {
public SortSlotAcc(FacetContext fcontext) { public SortSlotAcc(FacetContext fcontext) {
super(fcontext); super(fcontext);

View File

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.facet;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.queries.function.ValueSource;
/**
 * The {@code stddev} aggregation: wraps a value source, accumulates it with
 * {@link StddevSlotAcc}, and merges per-shard partial results into a single standard deviation.
 */
public class StddevAgg extends SimpleAggValueSource {

  public StddevAgg(ValueSource vs) {
    super("stddev", vs);
  }

  /** Creates the per-slot accumulator over this aggregation's argument. */
  @Override
  public SlotAcc createSlotAcc(FacetContext fcontext, int numDocs, int numSlots) throws IOException {
    return new StddevSlotAcc(getArg(), fcontext, numSlots);
  }

  /** Creates a fresh merger; the {@code prototype} argument is not consulted. */
  @Override
  public FacetMerger createFacetMerger(Object prototype) {
    return new Merger();
  }

  /**
   * Adds up the partial results it is given and derives the standard deviation from the totals.
   * Each partial result is read as a list of numbers: {@code [count, sumOfSquares, sum]}.
   */
  private static class Merger extends FacetDoubleMerger {
    long count;
    double sumSq;
    double sum;

    @Override
    @SuppressWarnings("unchecked")
    public void merge(Object facetResult, Context mcontext1) {
      final List<Number> parts = (List<Number>) facetResult;
      final Number partCount = parts.get(0);
      final Number partSumSq = parts.get(1);
      final Number partSum = parts.get(2);
      count += partCount.longValue();
      sumSq += partSumSq.doubleValue();
      sum += partSum.doubleValue();
    }

    @Override
    public Object getMergedResult() {
      return getDouble();
    }

    /**
     * @return {@code sqrt(sumSq/count - (sum/count)^2)} over everything merged so far, or
     *     {@code 0.0} when nothing with a non-zero count has been merged
     */
    @Override
    protected double getDouble() {
      if (count == 0) {
        return 0.0d;
      }
      final double meanOfSquares = sumSq / count;
      final double mean = sum / count;
      return Math.sqrt(meanOfSquares - Math.pow(mean, 2));
    }
  }
}

View File

@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.search.facet;
import java.io.IOException;
import java.util.List;
import org.apache.lucene.queries.function.ValueSource;
/**
 * The {@code variance} aggregation: wraps a value source, accumulates it with
 * {@link VarianceSlotAcc}, and merges per-shard partial results into a single variance.
 */
public class VarianceAgg extends SimpleAggValueSource {

  public VarianceAgg(ValueSource vs) {
    super("variance", vs);
  }

  /** Creates the per-slot accumulator over this aggregation's argument. */
  @Override
  public SlotAcc createSlotAcc(FacetContext fcontext, int numDocs, int numSlots) throws IOException {
    return new VarianceSlotAcc(getArg(), fcontext, numSlots);
  }

  /** Creates a fresh merger; the {@code prototype} argument is not consulted. */
  @Override
  public FacetMerger createFacetMerger(Object prototype) {
    return new Merger();
  }

  /**
   * Adds up the partial results it is given and derives the variance from the totals.
   * Each partial result is read as a list of numbers: {@code [count, sumOfSquares, sum]}.
   */
  private static class Merger extends FacetDoubleMerger {
    long count;
    double sumSq;
    double sum;

    @Override
    @SuppressWarnings("unchecked")
    public void merge(Object facetResult, Context mcontext1) {
      final List<Number> parts = (List<Number>) facetResult;
      final Number partCount = parts.get(0);
      final Number partSumSq = parts.get(1);
      final Number partSum = parts.get(2);
      count += partCount.longValue();
      sumSq += partSumSq.doubleValue();
      sum += partSum.doubleValue();
    }

    @Override
    public Object getMergedResult() {
      return getDouble();
    }

    /**
     * @return {@code sumSq/count - (sum/count)^2} over everything merged so far, or
     *     {@code 0.0} when nothing with a non-zero count has been merged
     */
    @Override
    protected double getDouble() {
      if (count == 0) {
        return 0.0d;
      }
      final double meanOfSquares = sumSq / count;
      final double mean = sum / count;
      return meanOfSquares - Math.pow(mean, 2);
    }
  }
}

View File

@ -1099,7 +1099,8 @@ public class QueryEqualityTest extends SolrTestCaseJ4 {
assertFuncEquals("agg_hll(foo_i)", "agg_hll(foo_i)"); assertFuncEquals("agg_hll(foo_i)", "agg_hll(foo_i)");
assertFuncEquals("agg_sumsq(foo_i)", "agg_sumsq(foo_i)"); assertFuncEquals("agg_sumsq(foo_i)", "agg_sumsq(foo_i)");
assertFuncEquals("agg_percentile(foo_i,50)", "agg_percentile(foo_i,50)"); assertFuncEquals("agg_percentile(foo_i,50)", "agg_percentile(foo_i,50)");
// assertFuncEquals("agg_stdev(foo_i)", "agg_stdev(foo_i)"); assertFuncEquals("agg_variance(foo_i)", "agg_variance(foo_i)");
assertFuncEquals("agg_stddev(foo_i)", "agg_stddev(foo_i)");
// assertFuncEquals("agg_multistat(foo_i)", "agg_multistat(foo_i)"); // assertFuncEquals("agg_multistat(foo_i)", "agg_multistat(foo_i)");
} }

View File

@ -529,6 +529,7 @@ public class TestJsonFacets extends SolrTestCaseHS {
" , f2:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'max(${num_d})'} } " + " , f2:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'max(${num_d})'} } " +
" , f3:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'unique(${where_s})'} } " + " , f3:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'unique(${where_s})'} } " +
" , f4:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'hll(${where_s})'} } " + " , f4:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'hll(${where_s})'} } " +
" , f5:{${terms} type:terms, field:'${cat_s}', sort:'x desc', facet:{x:'variance(${num_d})'} } " +
"}" "}"
) )
, "facets=={ 'count':6, " + , "facets=={ 'count':6, " +
@ -536,6 +537,7 @@ public class TestJsonFacets extends SolrTestCaseHS {
", f2:{ 'buckets':[{ val:'B', count:3, x:11.0 }, { val:'A', count:2, x:4.0 }]} " + ", f2:{ 'buckets':[{ val:'B', count:3, x:11.0 }, { val:'A', count:2, x:4.0 }]} " +
", f3:{ 'buckets':[{ val:'A', count:2, x:2 }, { val:'B', count:3, x:2 }]} " + ", f3:{ 'buckets':[{ val:'A', count:2, x:2 }, { val:'B', count:3, x:2 }]} " +
", f4:{ 'buckets':[{ val:'A', count:2, x:2 }, { val:'B', count:3, x:2 }]} " + ", f4:{ 'buckets':[{ val:'A', count:2, x:2 }, { val:'B', count:3, x:2 }]} " +
", f5:{ 'buckets':[{ val:'B', count:3, x:74.6666666666666 }, { val:'A', count:2, x:1.0 }]} " +
"}" "}"
); );
@ -845,19 +847,18 @@ public class TestJsonFacets extends SolrTestCaseHS {
); );
// stats at top level // stats at top level
client.testJQ(params(p, "q", "*:*" client.testJQ(params(p, "q", "*:*"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', avg2:'avg(def(${num_d},0))', min1:'min(${num_d})', max1:'max(${num_d})'" + , "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', avg2:'avg(def(${num_d},0))', min1:'min(${num_d})', max1:'max(${num_d})'" +
", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" + ", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" +
", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" + ", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" +
", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)' }" ", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)', variance:'variance(${num_d})', stddev:'stddev(${num_d})' }"
) )
, "facets=={ 'count':6, " + , "facets=={ 'count':6, " +
"sum1:3.0, sumsq1:247.0, avg1:0.6, avg2:0.5, min1:-9.0, max1:11.0" + "sum1:3.0, sumsq1:247.0, avg1:0.6, avg2:0.5, min1:-9.0, max1:11.0" +
", numwhere:2, unique_num_i:4, unique_num_d:5, unique_date:5" + ", numwhere:2, unique_num_i:4, unique_num_d:5, unique_date:5" +
", where_hll:2, hll_num_i:4, hll_num_d:5, hll_date:5" + ", where_hll:2, hll_num_i:4, hll_num_d:5, hll_date:5" +
", med:2.0, perc:[-9.0,2.0,11.0] }" ", med:2.0, perc:[-9.0,2.0,11.0], variance:49.04, stddev:7.002856560004639}"
); );
// stats at top level, no matches // stats at top level, no matches
@ -865,21 +866,20 @@ public class TestJsonFacets extends SolrTestCaseHS {
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})'" + , "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})'" +
", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" + ", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" +
", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" + ", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" +
", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)' }" ", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)', variance:'variance(${num_d})', stddev:'stddev(${num_d})' }"
) )
, "facets=={count:0 " + , "facets=={count:0 " +
"/* ,sum1:0.0, sumsq1:0.0, avg1:0.0, min1:'NaN', max1:'NaN', numwhere:0 */" + "\n// ,sum1:0.0, sumsq1:0.0, avg1:0.0, min1:'NaN', max1:'NaN', numwhere:0 \n" +
" }" " }"
); );
// stats at top level, matching documents, but no values in the field // stats at top level, matching documents, but no values in the field
// NOTE: this represents the current state of what is returned, not the ultimate desired state. // NOTE: this represents the current state of what is returned, not the ultimate desired state.
client.testJQ(params(p, "q", "id:3" client.testJQ(params(p, "q", "id:3"
, "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})'" + , "json.facet", "{ sum1:'sum(${num_d})', sumsq1:'sumsq(${num_d})', avg1:'avg(${num_d})', min1:'min(${num_d})', max1:'max(${num_d})'" +
", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" + ", numwhere:'unique(${where_s})', unique_num_i:'unique(${num_i})', unique_num_d:'unique(${num_d})', unique_date:'unique(${date})'" +
", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" + ", where_hll:'hll(${where_s})', hll_num_i:'hll(${num_i})', hll_num_d:'hll(${num_d})', hll_date:'hll(${date})'" +
", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)' }" ", med:'percentile(${num_d},50)', perc:'percentile(${num_d},0,50.0,100)', variance:'variance(${num_d})', stddev:'stddev(${num_d})' }"
) )
, "facets=={count:1 " + , "facets=={count:1 " +
",sum1:0.0," + ",sum1:0.0," +
@ -894,11 +894,12 @@ public class TestJsonFacets extends SolrTestCaseHS {
" where_hll:0," + " where_hll:0," +
" hll_num_i:0," + " hll_num_i:0," +
" hll_num_d:0," + " hll_num_d:0," +
" hll_date:0" + " hll_date:0," +
" variance:0.0," +
" stddev:0.0" +
" }" " }"
); );
// //
// tests on a multi-valued field with actual multiple values, just to ensure that we are // tests on a multi-valued field with actual multiple values, just to ensure that we are
// using a multi-valued method for the rest of the tests when appropriate. // using a multi-valued method for the rest of the tests when appropriate.