SOLR-10341: SQL AVG function mis-interprets field type

This commit is contained in:
Joel Bernstein 2017-03-28 09:25:25 +01:00
parent 1a80e4d694
commit aa2b46a62a
13 changed files with 320 additions and 63 deletions

View File

@ -69,16 +69,18 @@ class SolrAggregate extends Aggregate implements SolrRel {
for(Pair<AggregateCall, String> namedAggCall : getNamedAggCalls()) {
AggregateCall aggCall = namedAggCall.getKey();
Pair<String, String> metric = toSolrMetric(implementor, aggCall, inNames);
implementor.addReverseAggMapping(namedAggCall.getValue(), metric.getKey().toLowerCase(Locale.ROOT)+"("+metric.getValue()+")");
implementor.addMetricPair(namedAggCall.getValue(), metric.getKey(), metric.getValue());
/*
if(aggCall.getName() == null) {
System.out.println("AGG:"+namedAggCall.getValue()+":"+ aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")");
implementor.addFieldMapping(namedAggCall.getValue(),
aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")");
aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")");
}
*/
}
for(int group : getGroupSet()) {

View File

@ -43,6 +43,7 @@ class SolrEnumerator implements Enumerator<Object> {
* @param fields Fields to get from each Tuple
*/
SolrEnumerator(TupleStream tupleStream, List<Map.Entry<String, Class>> fields) {
this.tupleStream = tupleStream;
try {
this.tupleStream.open();

View File

@ -58,7 +58,7 @@ class SolrProject extends Project implements SolrRel {
for (Pair<RexNode, String> pair : getNamedProjects()) {
final String name = pair.right;
final String expr = pair.left.accept(translator);
implementor.addFieldMapping(name, expr);
implementor.addFieldMapping(name, expr, false);
}
}
}

View File

@ -47,9 +47,11 @@ interface SolrRel extends RelNode {
RelOptTable table;
SolrTable solrTable;
void addFieldMapping(String key, String val) {
if(key != null && !fieldMappings.containsKey(key)) {
this.fieldMappings.put(key, val);
void addFieldMapping(String key, String val, boolean overwrite) {
if(key != null) {
if(overwrite || !fieldMappings.containsKey(key)) {
this.fieldMappings.put(key, val);
}
}
}
@ -83,7 +85,7 @@ interface SolrRel extends RelNode {
String metricIdentifier = metric.toLowerCase(Locale.ROOT) + "(" + column + ")";
if(outName != null) {
this.addFieldMapping(outName, metricIdentifier);
this.addFieldMapping(outName, metricIdentifier, true);
}
}

View File

@ -99,10 +99,14 @@ class SolrSchema extends AbstractSchema {
case "string":
type = typeFactory.createJavaType(String.class);
break;
case "tint":
case "tlong":
case "int":
case "long":
type = typeFactory.createJavaType(Long.class);
break;
case "tfloat":
case "tdouble":
case "float":
case "double":
type = typeFactory.createJavaType(Double.class);

View File

@ -128,7 +128,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
tupleStream = handleSelect(zk, collection, q, fields, orders, limit);
} else {
if(buckets.isEmpty()) {
tupleStream = handleStats(zk, collection, q, metricPairs);
tupleStream = handleStats(zk, collection, q, metricPairs, fields);
} else {
if(mapReduce) {
tupleStream = handleGroupByMapReduce(zk,
@ -430,6 +430,11 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
final String limit,
final String havingPredicate) throws IOException {
Map<String, Class> fmap = new HashMap();
for(Map.Entry<String, Class> entry : fields) {
fmap.put(entry.getKey(), entry.getValue());
}
int numWorkers = Integer.parseInt(properties.getProperty("numWorkers", "1"));
Bucket[] buckets = buildBuckets(_buckets, fields);
@ -437,6 +442,13 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
if(metrics.length == 0) {
return handleSelectDistinctMapReduce(zk, collection, properties, fields, query, orders, buckets, limit);
} else {
for(Metric metric : metrics) {
Class c = fmap.get(metric.getIdentifier());
if(Long.class.equals(c)) {
metric.outputLong = true;
}
}
}
Set<String> fieldSet = getFieldSet(metrics, fields);
@ -556,6 +568,12 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
final String lim,
final String havingPredicate) throws IOException {
Map<String, Class> fmap = new HashMap();
for(Map.Entry<String, Class> f : fields) {
fmap.put(f.getKey(), f.getValue());
}
ModifiableSolrParams solrParams = new ModifiableSolrParams();
solrParams.add(CommonParams.Q, query);
@ -564,6 +582,13 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
if(metrics.length == 0) {
metrics = new Metric[1];
metrics[0] = new CountMetric();
} else {
for(Metric metric : metrics) {
Class c = fmap.get(metric.getIdentifier());
if(Long.class.equals(c)) {
metric.outputLong = true;
}
}
}
int limit = lim != null ? Integer.parseInt(lim) : 1000;
@ -767,12 +792,26 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
private TupleStream handleStats(String zk,
String collection,
String query,
List<Pair<String, String>> metricPairs) {
List<Pair<String, String>> metricPairs,
List<Map.Entry<String, Class>> fields) {
Map<String, Class> fmap = new HashMap();
for(Map.Entry<String, Class> entry : fields) {
fmap.put(entry.getKey(), entry.getValue());
}
ModifiableSolrParams solrParams = new ModifiableSolrParams();
solrParams.add(CommonParams.Q, query);
Metric[] metrics = buildMetrics(metricPairs, false).toArray(new Metric[0]);
for(Metric metric : metrics) {
Class c = fmap.get(metric.getIdentifier());
if(Long.class.equals(c)) {
metric.outputLong = true;
}
}
return new StatsStream(zk, collection, solrParams, metrics);
}

View File

@ -93,6 +93,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
}
private List<String> generateFields(List<String> queryFields, Map<String, String> fieldMappings) {
if(fieldMappings.isEmpty()) {
return queryFields;
} else {

View File

@ -88,6 +88,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
testWhere();
testMixedCaseFields();
testBasicGrouping();
testBasicGroupingTint();
testBasicGroupingFacets();
testSelectDistinct();
testSelectDistinctFacets();
@ -669,7 +670,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
commit();
SolrParams sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2");
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2");
SolrStream solrStream = new SolrStream(jetty.url, sParams);
List<Tuple> tuples = getTuples(solrStream);
@ -684,7 +685,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
@ -692,10 +693,36 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s as myString, count(*), sum(field_i) as mySum, min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2");
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) as blah from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
//Only two results because of the limit.
assert(tuples.size() == 2);
tuple = tuples.get(0);
assert(tuple.get("str_s").equals("b"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("blah") == 9.5); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("blah") == 13.5); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s as myString, count(*), sum(field_i) as mySum, min(field_i), max(field_i), avg(field_i) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
@ -709,7 +736,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 19);
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("myString").equals("a"));
@ -717,11 +744,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 27);
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), "
+ "cast(avg(1.0 * field_i) as float) from collection1 where (text='XXXX' AND NOT ((text='XXXY') AND (text='XXXY' OR text='XXXY'))) "
+ "avg(field_i) from collection1 where (text='XXXX' AND NOT ((text='XXXY') AND (text='XXXY' OR text='XXXY'))) "
+ "group by str_s order by str_s desc");
solrStream = new SolrStream(jetty.url, sParams);
@ -746,7 +773,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10D); //avg(field_i)
tuple = tuples.get(2);
assert(tuple.get("str_s").equals("a"));
@ -755,11 +782,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s as myString, count(*) as myCount, sum(field_i) as mySum, min(field_i) as myMin, "
+ "max(field_i) as myMax, cast(avg(1.0 * field_i) as float) as myAvg from collection1 "
+ "max(field_i) as myMax, avg(field_i) as myAvg from collection1 "
+ "where (text='XXXX' AND NOT (text='XXXY')) group by str_s order by str_s desc");
solrStream = new SolrStream(jetty.url, sParams);
@ -784,7 +811,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 19);
assert(tuple.getDouble("myMin") == 8);
assert(tuple.getDouble("myMax") == 11);
assert(tuple.getDouble("myAvg") == 9.5D);
assert(tuple.getDouble("myAvg") == 10);
tuple = tuples.get(2);
assert(tuple.get("myString").equals("a"));
@ -792,10 +819,10 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 27);
assert(tuple.getDouble("myMin") == 7);
assert(tuple.getDouble("myMax") == 20);
assert(tuple.getDouble("myAvg") == 13.5D);
assert(tuple.getDouble("myAvg") == 14);
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) " +
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) " +
"from collection1 where text='XXXX' group by str_s having sum(field_i) = 19");
solrStream = new SolrStream(jetty.url, sParams);
@ -809,10 +836,10 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) " +
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) " +
"from collection1 where text='XXXX' group by str_s having ((sum(field_i) = 19) AND (min(field_i) = 8))");
solrStream = new SolrStream(jetty.url, sParams);
@ -826,11 +853,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i) as mySum, min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"having ((sum(field_i) = 19) AND (min(field_i) = 8))");
solrStream = new SolrStream(jetty.url, sParams);
@ -844,11 +871,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 19);
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"having ((sum(field_i) = 19) AND (min(field_i) = 100))");
solrStream = new SolrStream(jetty.url, sParams);
@ -860,6 +887,60 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
}
}
private void testBasicGroupingTint() throws Exception {
try {
CloudJettyRunner jetty = this.cloudJettys.get(0);
del("*:*");
commit();
indexr("id", "1", "text", "XXXX XXXX", "str_s", "a", "field_ti", "7");
indexr("id", "2", "text", "XXXX XXXX", "str_s", "b", "field_ti", "8");
indexr("id", "3", "text", "XXXX XXXX", "str_s", "a", "field_ti", "20");
indexr("id", "4", "text", "XXXX XXXX", "str_s", "b", "field_ti", "11");
indexr("id", "5", "text", "XXXX XXXX", "str_s", "c", "field_ti", "30");
indexr("id", "6", "text", "XXXX XXXX", "str_s", "c", "field_ti", "40");
indexr("id", "7", "text", "XXXX XXXX", "str_s", "c", "field_ti", "50");
indexr("id", "8", "text", "XXXX XXXX", "str_s", "c", "field_ti", "60");
indexr("id", "9", "text", "XXXX XXXY", "str_s", "d", "field_ti", "70");
commit();
SolrParams sParams = mapParams(CommonParams.QT, "/sql",
"stmt", "select str_s, count(*), sum(field_ti), min(field_ti), max(field_ti), avg(field_ti) from collection1 where text='XXXX' group by str_s order by sum(field_ti) asc limit 2");
SolrStream solrStream = new SolrStream(jetty.url, sParams);
List<Tuple> tuples = getTuples(solrStream);
//Only two results because of the limit.
assert(tuples.size() == 2);
Tuple tuple;
tuple = tuples.get(0);
assert(tuple.get("str_s").equals("b"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
} finally {
delete();
}
}
private void testSelectDistinctFacets() throws Exception {
try {
CloudJettyRunner jetty = this.cloudJettys.get(0);
@ -1506,6 +1587,35 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "facet",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"order by sum(field_i) asc limit 2");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
//Only two results because of the limit.
assert(tuples.size() == 2);
tuple = tuples.get(0);
assert(tuple.get("str_s").equals("b"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "facet",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), "
+ "cast(avg(1.0 * field_i) as float) from collection1 where (text='XXXX' AND NOT (text='XXXY')) "
@ -1667,7 +1777,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
SolrParams sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"order by sum(field_i) asc limit 2");
SolrStream solrStream = new SolrStream(jetty.url, sParams);
@ -1684,7 +1794,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
@ -1692,12 +1802,41 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"order by sum(field_i) asc limit 2");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
//Only two results because of the limit.
assert(tuples.size() == 2);
tuple = tuples.get(0);
assert(tuple.get("str_s").equals("b"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
assert(tuple.getDouble("EXPR$1") == 2); //count(*)
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i) as mySum, min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2");
"avg(field_i) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
@ -1711,7 +1850,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 19);
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(1);
assert(tuple.get("str_s").equals("a"));
@ -1719,12 +1858,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("mySum") == 27);
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by str_s desc");
"avg(field_i) from collection1 where text='XXXX' group by str_s order by str_s desc");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
@ -1748,7 +1887,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(2);
assert(tuple.get("str_s").equals("a"));
@ -1756,12 +1895,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s as myString, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by myString desc");
"avg(field_i) from collection1 where text='XXXX' group by str_s order by myString desc");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
@ -1785,7 +1924,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
tuple = tuples.get(2);
assert(tuple.get("myString").equals("a"));
@ -1793,12 +1932,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 7); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 20); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s having sum(field_i) = 19");
"avg(field_i) from collection1 where text='XXXX' group by str_s having sum(field_i) = 19");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
@ -1811,11 +1950,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"having ((sum(field_i) = 19) AND (min(field_i) = 8))");
solrStream = new SolrStream(jetty.url, sParams);
@ -1829,11 +1968,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i)
assert(tuple.getDouble("EXPR$3") == 8); //min(field_i)
assert(tuple.getDouble("EXPR$4") == 11); //max(field_i)
assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i)
assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i)
sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce",
"stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " +
"cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " +
"avg(field_i) from collection1 where text='XXXX' group by str_s " +
"having ((sum(field_i) = 19) AND (min(field_i) = 100))");
solrStream = new SolrStream(jetty.url, sParams);
@ -1933,6 +2072,45 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase {
assertTrue(maxf == 10.0D);
assertTrue(avgf == 5.5D);
//Test without cast on average int field
sParams = mapParams(CommonParams.QT, "/sql",
"stmt", "select count(*) as myCount, sum(a_i) as mySum, min(a_i) as myMin, max(a_i) as myMax, " +
"avg(a_i) as myAvg, sum(a_f), min(a_f), max(a_f), avg(a_f) from collection1");
solrStream = new SolrStream(jetty.url, sParams);
tuples = getTuples(solrStream);
assert(tuples.size() == 1);
//Test Long and Double Sums
tuple = tuples.get(0);
count = tuple.getDouble("myCount");
sumi = tuple.getDouble("mySum");
mini = tuple.getDouble("myMin");
maxi = tuple.getDouble("myMax");
avgi = tuple.getDouble("myAvg");
assertTrue(tuple.get("myAvg") instanceof Long);
sumf = tuple.getDouble("EXPR$5"); //sum(a_f)
minf = tuple.getDouble("EXPR$6"); //min(a_f)
maxf = tuple.getDouble("EXPR$7"); //max(a_f)
avgf = tuple.getDouble("EXPR$8"); //avg(a_f)
assertTrue(count == 10);
assertTrue(mini == 0.0D);
assertTrue(maxi == 14.0D);
assertTrue(sumi == 70);
assertTrue(avgi == 7);
assertTrue(sumf == 55.0D);
assertTrue(minf == 1.0D);
assertTrue(maxf == 10.0D);
assertTrue(avgf == 5.5D);
// Test where clause hits
sParams = mapParams(CommonParams.QT, "/sql",
"stmt", "select count(*), sum(a_i), min(a_i), max(a_i), cast(avg(1.0 * a_i) as float), sum(a_f), " +

View File

@ -234,7 +234,6 @@ public class FacetStream extends TupleStream implements Expressible {
this.zkHost = zkHost;
this.params = params;
this.buckets = buckets;
System.out.println("####### Bucket count:"+buckets.length);
this.metrics = metrics;
this.bucketSizeLimit = bucketSizeLimit;
this.collection = collection;
@ -356,6 +355,7 @@ public class FacetStream extends TupleStream implements Expressible {
NamedList response = cloudSolrClient.request(request, collection);
getTuples(response, buckets, metrics);
Collections.sort(tuples, getStreamSort());
} catch (Exception e) {
throw new IOException(e);
}
@ -509,7 +509,11 @@ public class FacetStream extends TupleStream implements Expressible {
String identifier = metric.getIdentifier();
if(!identifier.startsWith("count(")) {
double d = (double)bucket.get("facet_"+m);
t.put(identifier, d);
if(metric.outputLong) {
t.put(identifier, Math.round(d));
} else {
t.put(identifier, d);
}
++m;
} else {
long l = ((Number)bucket.get("count")).longValue();

View File

@ -58,6 +58,7 @@ public class StatsStream extends TupleStream implements Expressible {
private String collection;
private boolean done;
private boolean doCount;
private Map<String, Metric> metricMap;
protected transient SolrClientCache cache;
protected transient CloudSolrClient cloudSolrClient;
@ -82,6 +83,10 @@ public class StatsStream extends TupleStream implements Expressible {
this.params = params;
this.metrics = metrics;
this.collection = collection;
metricMap = new HashMap();
for(Metric metric : metrics) {
metricMap.put(metric.getIdentifier(), metric);
}
}
public StatsStream(StreamExpression expression, StreamFactory factory) throws IOException{
@ -321,7 +326,14 @@ public class StatsStream extends TupleStream implements Expressible {
private void addStat(Map<String, Object> map, String field, String stat, Object val) {
if(stat.equals("mean")) {
map.put("avg("+field+")", val);
String name = "avg("+field+")";
Metric m = metricMap.get(name);
if(m.outputLong) {
Number num = (Number) val;
map.put(name, Math.round(num.doubleValue()));
} else {
map.put(name, val);
}
} else {
map.put(stat+"("+field+")", val);
}

View File

@ -37,27 +37,36 @@ public class MeanMetric extends Metric {
private long count;
public MeanMetric(String columnName){
init("avg", columnName);
init("avg", columnName, false);
}
public MeanMetric(String columnName, boolean outputLong){
init("avg", columnName, outputLong);
}
public MeanMetric(StreamExpression expression, StreamFactory factory) throws IOException{
// grab all parameters out
String functionName = expression.getFunctionName();
String columnName = factory.getValueOperand(expression, 0);
String outputLong = factory.getValueOperand(expression, 1);
// validate expression contains only what we want.
if(null == columnName){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(columnName)", expression, functionName));
}
if(1 != expression.getParameters().size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression));
boolean ol = false;
if(outputLong != null) {
ol = Boolean.parseBoolean(outputLong);
}
init(functionName, columnName);
init(functionName, columnName, ol);
}
private void init(String functionName, String columnName){
private void init(String functionName, String columnName, boolean outputLong){
this.columnName = columnName;
this.outputLong = outputLong;
setFunctionName(functionName);
setIdentifier(functionName, "(", columnName, ")");
}
@ -75,25 +84,29 @@ public class MeanMetric extends Metric {
}
public Metric newInstance() {
return new MeanMetric(columnName);
return new MeanMetric(columnName, outputLong);
}
public String[] getColumns() {
return new String[]{columnName};
}
public Double getValue() {
public Number getValue() {
double dcount = (double)count;
if(longSum == 0) {
return doubleSum/dcount;
} else {
return longSum/dcount;
double mean = longSum/dcount;
if(outputLong) {
return Math.round(mean);
} else {
return mean;
}
}
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
return new StreamExpression(getFunctionName()).withParameter(columnName);
return new StreamExpression(getFunctionName()).withParameter(columnName).withParameter(Boolean.toString(outputLong));
}
}

View File

@ -30,6 +30,7 @@ public abstract class Metric implements Expressible {
private UUID metricNodeId = UUID.randomUUID();
private String functionName;
private String identifier;
public boolean outputLong; // This is only used for SQL in facet mode.
public String getFunctionName(){
return functionName;

View File

@ -155,7 +155,7 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
assertTrue(expressionString.contains("sort=\"a_f asc, a_i asc\""));
assertTrue(expressionString.contains("min(a_i)"));
assertTrue(expressionString.contains("max(a_i)"));
assertTrue(expressionString.contains("avg(a_i)"));
assertTrue(expressionString.contains("avg(a_i,false)"));
assertTrue(expressionString.contains("count(*)"));
assertTrue(expressionString.contains("sum(a_i)"));
@ -274,8 +274,8 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
assertTrue(expressionString.contains("min(a_f)"));
assertTrue(expressionString.contains("max(a_i)"));
assertTrue(expressionString.contains("max(a_f)"));
assertTrue(expressionString.contains("avg(a_i)"));
assertTrue(expressionString.contains("avg(a_f)"));
assertTrue(expressionString.contains("avg(a_i,false)"));
assertTrue(expressionString.contains("avg(a_f,false)"));
assertTrue(expressionString.contains("count(*)"));
}
@ -427,7 +427,7 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase {
metric = new MeanMetric(StreamExpressionParser.parse("avg(foo)"), factory);
expressionString = metric.toExpression(factory).toString();
assertEquals("avg(foo)", expressionString);
assertEquals("avg(foo,false)", expressionString);
}
@Test