diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java index 8c4d46d1c14..f207eeb4419 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java @@ -69,16 +69,18 @@ class SolrAggregate extends Aggregate implements SolrRel { for(Pair namedAggCall : getNamedAggCalls()) { - AggregateCall aggCall = namedAggCall.getKey(); Pair metric = toSolrMetric(implementor, aggCall, inNames); implementor.addReverseAggMapping(namedAggCall.getValue(), metric.getKey().toLowerCase(Locale.ROOT)+"("+metric.getValue()+")"); implementor.addMetricPair(namedAggCall.getValue(), metric.getKey(), metric.getValue()); + /* if(aggCall.getName() == null) { + System.out.println("AGG:"+namedAggCall.getValue()+":"+ aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")"); implementor.addFieldMapping(namedAggCall.getValue(), - aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")"); + aggCall.getAggregation().getName() + "(" + inNames.get(aggCall.getArgList().get(0)) + ")"); } + */ } for(int group : getGroupSet()) { diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrEnumerator.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrEnumerator.java index be6046c98fe..7ba3838ce79 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrEnumerator.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrEnumerator.java @@ -43,6 +43,7 @@ class SolrEnumerator implements Enumerator { * @param fields Fields to get from each Tuple */ SolrEnumerator(TupleStream tupleStream, List> fields) { + this.tupleStream = tupleStream; try { this.tupleStream.open(); diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrProject.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrProject.java index c4217f27e38..bd36ba8e4b1 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrProject.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrProject.java @@ -58,7 +58,7 @@ class SolrProject extends Project implements SolrRel { for (Pair pair : getNamedProjects()) { final String name = pair.right; final String expr = pair.left.accept(translator); - implementor.addFieldMapping(name, expr); + implementor.addFieldMapping(name, expr, false); } } } diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java index d4de2c68c37..370de16d886 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java @@ -47,9 +47,11 @@ interface SolrRel extends RelNode { RelOptTable table; SolrTable solrTable; - void addFieldMapping(String key, String val) { - if(key != null && !fieldMappings.containsKey(key)) { - this.fieldMappings.put(key, val); + void addFieldMapping(String key, String val, boolean overwrite) { + if(key != null) { + if(overwrite || !fieldMappings.containsKey(key)) { + this.fieldMappings.put(key, val); + } } } @@ -83,7 +85,7 @@ interface SolrRel extends RelNode { String metricIdentifier = metric.toLowerCase(Locale.ROOT) + "(" + column + ")"; if(outName != null) { - this.addFieldMapping(outName, metricIdentifier); + this.addFieldMapping(outName, metricIdentifier, true); } } diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrSchema.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrSchema.java index 83fa5379f41..20d01f33b34 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrSchema.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrSchema.java @@ -99,10 +99,14 @@ class SolrSchema extends AbstractSchema { case "string": type = typeFactory.createJavaType(String.class); break; + case "tint": + case "tlong": case "int": case "long": type = typeFactory.createJavaType(Long.class); break; + case "tfloat": + case "tdouble": case "float": case "double": type = typeFactory.createJavaType(Double.class); diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java index e313b440ce7..b7f552b6adf 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java @@ -128,7 +128,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { tupleStream = handleSelect(zk, collection, q, fields, orders, limit); } else { if(buckets.isEmpty()) { - tupleStream = handleStats(zk, collection, q, metricPairs); + tupleStream = handleStats(zk, collection, q, metricPairs, fields); } else { if(mapReduce) { tupleStream = handleGroupByMapReduce(zk, @@ -430,6 +430,11 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { final String limit, final String havingPredicate) throws IOException { + Map fmap = new HashMap(); + for(Map.Entry entry : fields) { + fmap.put(entry.getKey(), entry.getValue()); + } + int numWorkers = Integer.parseInt(properties.getProperty("numWorkers", "1")); Bucket[] buckets = buildBuckets(_buckets, fields); @@ -437,6 +442,13 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { if(metrics.length == 0) { return handleSelectDistinctMapReduce(zk, collection, properties, fields, query, orders, buckets, limit); + } else { + for(Metric metric : metrics) { + Class c = fmap.get(metric.getIdentifier()); + if(Long.class.equals(c)) { + metric.outputLong = true; + } + } } Set fieldSet = getFieldSet(metrics, fields); @@ -556,6 +568,12 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { final String lim, final String havingPredicate) throws IOException { + + Map fmap = new HashMap(); + for(Map.Entry f : fields) { + fmap.put(f.getKey(), f.getValue()); + } + ModifiableSolrParams solrParams = new ModifiableSolrParams(); solrParams.add(CommonParams.Q, query); @@ -564,6 +582,13 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { if(metrics.length == 0) { metrics = new Metric[1]; metrics[0] = new CountMetric(); + } else { + for(Metric metric : metrics) { + Class c = fmap.get(metric.getIdentifier()); + if(Long.class.equals(c)) { + metric.outputLong = true; + } + } } int limit = lim != null ? Integer.parseInt(lim) : 1000; @@ -767,12 +792,26 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { private TupleStream handleStats(String zk, String collection, String query, - List> metricPairs) { + List> metricPairs, + List> fields) { + Map fmap = new HashMap(); + for(Map.Entry entry : fields) { + fmap.put(entry.getKey(), entry.getValue()); + } + ModifiableSolrParams solrParams = new ModifiableSolrParams(); solrParams.add(CommonParams.Q, query); Metric[] metrics = buildMetrics(metricPairs, false).toArray(new Metric[0]); + + for(Metric metric : metrics) { + Class c = fmap.get(metric.getIdentifier()); + if(Long.class.equals(c)) { + metric.outputLong = true; + } + } + return new StatsStream(zk, collection, solrParams, metrics); } diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java index 10d4d4c9688..c97303b4a42 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java @@ -93,6 +93,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel { } private List generateFields(List queryFields, Map fieldMappings) { + if(fieldMappings.isEmpty()) { return queryFields; } else { diff --git a/solr/core/src/test/org/apache/solr/handler/TestSQLHandler.java b/solr/core/src/test/org/apache/solr/handler/TestSQLHandler.java index cb16f033a24..4889c9071a1 100644 --- a/solr/core/src/test/org/apache/solr/handler/TestSQLHandler.java +++ b/solr/core/src/test/org/apache/solr/handler/TestSQLHandler.java @@ -88,6 +88,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { testWhere(); testMixedCaseFields(); testBasicGrouping(); + testBasicGroupingTint(); testBasicGroupingFacets(); testSelectDistinct(); testSelectDistinctFacets(); @@ -669,7 +670,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { commit(); SolrParams sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", - "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2"); + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2"); SolrStream solrStream = new SolrStream(jetty.url, sParams); List tuples = getTuples(solrStream); @@ -684,7 +685,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(1); assert(tuple.get("str_s").equals("a")); @@ -692,10 +693,36 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) + sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", - "stmt", "select str_s as myString, count(*), sum(field_i) as mySum, min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2"); + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) as blah from collection1 where text='XXXX' group by str_s order by sum(field_i) asc limit 2"); + + solrStream = new SolrStream(jetty.url, sParams); + tuples = getTuples(solrStream); + + //Only two results because of the limit. + assert(tuples.size() == 2); + + tuple = tuples.get(0); + assert(tuple.get("str_s").equals("b")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) + assert(tuple.getDouble("blah") == 9.5); //avg(field_i) + + tuple = tuples.get(1); + assert(tuple.get("str_s").equals("a")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) + assert(tuple.getDouble("blah") == 13.5); //avg(field_i) + + sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", + "stmt", "select str_s as myString, count(*), sum(field_i) as mySum, min(field_i), max(field_i), avg(field_i) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2"); solrStream = new SolrStream(jetty.url, sParams); tuples = getTuples(solrStream); @@ -709,7 +736,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 19); assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(1); assert(tuple.get("myString").equals("a")); @@ -717,11 +744,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 27); assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " - + "cast(avg(1.0 * field_i) as float) from collection1 where (text='XXXX' AND NOT ((text='XXXY') AND (text='XXXY' OR text='XXXY'))) " + + "avg(field_i) from collection1 where (text='XXXX' AND NOT ((text='XXXY') AND (text='XXXY' OR text='XXXY'))) " + "group by str_s order by str_s desc"); solrStream = new SolrStream(jetty.url, sParams); @@ -746,7 +773,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10D); //avg(field_i) tuple = tuples.get(2); assert(tuple.get("str_s").equals("a")); @@ -755,11 +782,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", "stmt", "select str_s as myString, count(*) as myCount, sum(field_i) as mySum, min(field_i) as myMin, " - + "max(field_i) as myMax, cast(avg(1.0 * field_i) as float) as myAvg from collection1 " + + "max(field_i) as myMax, avg(field_i) as myAvg from collection1 " + "where (text='XXXX' AND NOT (text='XXXY')) group by str_s order by str_s desc"); solrStream = new SolrStream(jetty.url, sParams); @@ -784,7 +811,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 19); assert(tuple.getDouble("myMin") == 8); assert(tuple.getDouble("myMax") == 11); - assert(tuple.getDouble("myAvg") == 9.5D); + assert(tuple.getDouble("myAvg") == 10); tuple = tuples.get(2); assert(tuple.get("myString").equals("a")); @@ -792,10 +819,10 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 27); assert(tuple.getDouble("myMin") == 7); assert(tuple.getDouble("myMax") == 20); - assert(tuple.getDouble("myAvg") == 13.5D); + assert(tuple.getDouble("myAvg") == 14); sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", - "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) " + + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) " + "from collection1 where text='XXXX' group by str_s having sum(field_i) = 19"); solrStream = new SolrStream(jetty.url, sParams); @@ -809,10 +836,10 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", - "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), cast(avg(1.0 * field_i) as float) " + + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), avg(field_i) " + "from collection1 where text='XXXX' group by str_s having ((sum(field_i) = 19) AND (min(field_i) = 8))"); solrStream = new SolrStream(jetty.url, sParams); @@ -826,11 +853,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i) as mySum, min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + "having ((sum(field_i) = 19) AND (min(field_i) = 8))"); solrStream = new SolrStream(jetty.url, sParams); @@ -844,11 +871,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 19); assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + "having ((sum(field_i) = 19) AND (min(field_i) = 100))"); solrStream = new SolrStream(jetty.url, sParams); @@ -860,6 +887,60 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { } } + + private void testBasicGroupingTint() throws Exception { + try { + + CloudJettyRunner jetty = this.cloudJettys.get(0); + + del("*:*"); + + commit(); + + indexr("id", "1", "text", "XXXX XXXX", "str_s", "a", "field_ti", "7"); + indexr("id", "2", "text", "XXXX XXXX", "str_s", "b", "field_ti", "8"); + indexr("id", "3", "text", "XXXX XXXX", "str_s", "a", "field_ti", "20"); + indexr("id", "4", "text", "XXXX XXXX", "str_s", "b", "field_ti", "11"); + indexr("id", "5", "text", "XXXX XXXX", "str_s", "c", "field_ti", "30"); + indexr("id", "6", "text", "XXXX XXXX", "str_s", "c", "field_ti", "40"); + indexr("id", "7", "text", "XXXX XXXX", "str_s", "c", "field_ti", "50"); + indexr("id", "8", "text", "XXXX XXXX", "str_s", "c", "field_ti", "60"); + indexr("id", "9", "text", "XXXX XXXY", "str_s", "d", "field_ti", "70"); + commit(); + + SolrParams sParams = mapParams(CommonParams.QT, "/sql", + "stmt", "select str_s, count(*), sum(field_ti), min(field_ti), max(field_ti), avg(field_ti) from collection1 where text='XXXX' group by str_s order by sum(field_ti) asc limit 2"); + + SolrStream solrStream = new SolrStream(jetty.url, sParams); + List tuples = getTuples(solrStream); + + //Only two results because of the limit. + assert(tuples.size() == 2); + Tuple tuple; + + tuple = tuples.get(0); + assert(tuple.get("str_s").equals("b")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) + + tuple = tuples.get(1); + assert(tuple.get("str_s").equals("a")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) + + + + } finally { + delete(); + } + } + private void testSelectDistinctFacets() throws Exception { try { CloudJettyRunner jetty = this.cloudJettys.get(0); @@ -1506,6 +1587,35 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + + sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "facet", + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + + "order by sum(field_i) asc limit 2"); + + solrStream = new SolrStream(jetty.url, sParams); + tuples = getTuples(solrStream); + + //Only two results because of the limit. + assert(tuples.size() == 2); + + tuple = tuples.get(0); + assert(tuple.get("str_s").equals("b")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) + + tuple = tuples.get(1); + assert(tuple.get("str_s").equals("a")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) + + sParams = mapParams(CommonParams.QT, "/sql", "aggregationMode", "facet", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + "cast(avg(1.0 * field_i) as float) from collection1 where (text='XXXX' AND NOT (text='XXXY')) " @@ -1667,7 +1777,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { SolrParams sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + "order by sum(field_i) asc limit 2"); SolrStream solrStream = new SolrStream(jetty.url, sParams); @@ -1684,7 +1794,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(1); assert(tuple.get("str_s").equals("a")); @@ -1692,12 +1802,41 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) + + + sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", + "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + + "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "order by sum(field_i) asc limit 2"); + + solrStream = new SolrStream(jetty.url, sParams); + tuples = getTuples(solrStream); + + //Only two results because of the limit. + assert(tuples.size() == 2); + + tuple = tuples.get(0); + assert(tuple.get("str_s").equals("b")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 9.5); //avg(field_i) + + tuple = tuples.get(1); + assert(tuple.get("str_s").equals("a")); + assert(tuple.getDouble("EXPR$1") == 2); //count(*) + assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) + assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) + assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) + assert(tuple.getDouble("EXPR$5") == 13.5); //avg(field_i) + sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i) as mySum, min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2"); + "avg(field_i) from collection1 where text='XXXX' group by str_s order by mySum asc limit 2"); solrStream = new SolrStream(jetty.url, sParams); tuples = getTuples(solrStream); @@ -1711,7 +1850,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 19); assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(1); assert(tuple.get("str_s").equals("a")); @@ -1719,12 +1858,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("mySum") == 27); assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by str_s desc"); + "avg(field_i) from collection1 where text='XXXX' group by str_s order by str_s desc"); solrStream = new SolrStream(jetty.url, sParams); tuples = getTuples(solrStream); @@ -1748,7 +1887,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(2); assert(tuple.get("str_s").equals("a")); @@ -1756,12 +1895,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s as myString, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s order by myString desc"); + "avg(field_i) from collection1 where text='XXXX' group by str_s order by myString desc"); solrStream = new SolrStream(jetty.url, sParams); tuples = getTuples(solrStream); @@ -1785,7 +1924,7 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) tuple = tuples.get(2); assert(tuple.get("myString").equals("a")); @@ -1793,12 +1932,12 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 27); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 7); //min(field_i) assert(tuple.getDouble("EXPR$4") == 20); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 13.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 14); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s having sum(field_i) = 19"); + "avg(field_i) from collection1 where text='XXXX' group by str_s having sum(field_i) = 19"); solrStream = new SolrStream(jetty.url, sParams); tuples = getTuples(solrStream); @@ -1811,11 +1950,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + "having ((sum(field_i) = 19) AND (min(field_i) = 8))"); solrStream = new SolrStream(jetty.url, sParams); @@ -1829,11 +1968,11 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assert(tuple.getDouble("EXPR$2") == 19); //sum(field_i) assert(tuple.getDouble("EXPR$3") == 8); //min(field_i) assert(tuple.getDouble("EXPR$4") == 11); //max(field_i) - assert(tuple.getDouble("EXPR$5") == 9.5D); //avg(field_i) + assert(tuple.getDouble("EXPR$5") == 10); //avg(field_i) sParams = mapParams(CommonParams.QT, "/sql", "numWorkers", "2", "aggregationMode", "map_reduce", "stmt", "select str_s, count(*), sum(field_i), min(field_i), max(field_i), " + - "cast(avg(1.0 * field_i) as float) from collection1 where text='XXXX' group by str_s " + + "avg(field_i) from collection1 where text='XXXX' group by str_s " + "having ((sum(field_i) = 19) AND (min(field_i) = 100))"); solrStream = new SolrStream(jetty.url, sParams); @@ -1933,6 +2072,45 @@ public class TestSQLHandler extends AbstractFullDistribZkTestBase { assertTrue(maxf == 10.0D); assertTrue(avgf == 5.5D); + + //Test without cast on average int field + sParams = mapParams(CommonParams.QT, "/sql", + "stmt", "select count(*) as myCount, sum(a_i) as mySum, min(a_i) as myMin, max(a_i) as myMax, " + + "avg(a_i) as myAvg, sum(a_f), min(a_f), max(a_f), avg(a_f) from collection1"); + + solrStream = new SolrStream(jetty.url, sParams); + + tuples = getTuples(solrStream); + + assert(tuples.size() == 1); + + //Test Long and Double Sums + + tuple = tuples.get(0); + + count = tuple.getDouble("myCount"); + sumi = tuple.getDouble("mySum"); + mini = tuple.getDouble("myMin"); + maxi = tuple.getDouble("myMax"); + avgi = tuple.getDouble("myAvg"); + assertTrue(tuple.get("myAvg") instanceof Long); + sumf = tuple.getDouble("EXPR$5"); //sum(a_f) + minf = tuple.getDouble("EXPR$6"); //min(a_f) + maxf = tuple.getDouble("EXPR$7"); //max(a_f) + avgf = tuple.getDouble("EXPR$8"); //avg(a_f) + + assertTrue(count == 10); + assertTrue(mini == 0.0D); + assertTrue(maxi == 14.0D); + assertTrue(sumi == 70); + assertTrue(avgi == 7); + assertTrue(sumf == 55.0D); + assertTrue(minf == 1.0D); + assertTrue(maxf == 10.0D); + assertTrue(avgf == 5.5D); + + + // Test where clause hits sParams = mapParams(CommonParams.QT, "/sql", "stmt", "select count(*), sum(a_i), min(a_i), max(a_i), cast(avg(1.0 * a_i) as float), sum(a_f), " + diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java index 94d937da566..0180764ff92 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java @@ -234,7 +234,6 @@ public class FacetStream extends TupleStream implements Expressible { this.zkHost = zkHost; this.params = params; this.buckets = buckets; - System.out.println("####### Bucket count:"+buckets.length); this.metrics = metrics; this.bucketSizeLimit = bucketSizeLimit; this.collection = collection; @@ -356,6 +355,7 @@ public class FacetStream extends TupleStream implements Expressible { NamedList response = cloudSolrClient.request(request, collection); getTuples(response, buckets, metrics); Collections.sort(tuples, getStreamSort()); + } catch (Exception e) { throw new IOException(e); } @@ -509,7 +509,11 @@ public class FacetStream extends TupleStream implements Expressible { String identifier = metric.getIdentifier(); if(!identifier.startsWith("count(")) { double d = (double)bucket.get("facet_"+m); - t.put(identifier, d); + if(metric.outputLong) { + t.put(identifier, Math.round(d)); + } else { + t.put(identifier, d); + } ++m; } else { long l = ((Number)bucket.get("count")).longValue(); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java index 65389028ee4..f6b5818be72 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java @@ -58,6 +58,7 @@ public class StatsStream extends TupleStream implements Expressible { private String collection; private boolean done; private boolean doCount; + private Map metricMap; protected transient SolrClientCache cache; protected transient CloudSolrClient cloudSolrClient; @@ -82,6 +83,10 @@ public class StatsStream extends TupleStream implements Expressible { this.params = params; this.metrics = metrics; this.collection = collection; + metricMap = new HashMap(); + for(Metric metric : metrics) { + metricMap.put(metric.getIdentifier(), metric); + } } public StatsStream(StreamExpression expression, StreamFactory factory) throws IOException{ @@ -321,7 +326,14 @@ public class StatsStream extends TupleStream implements Expressible { private void addStat(Map map, String field, String stat, Object val) { if(stat.equals("mean")) { - map.put("avg("+field+")", val); + String name = "avg("+field+")"; + Metric m = metricMap.get(name); + if(m.outputLong) { + Number num = (Number) val; + map.put(name, Math.round(num.doubleValue())); + } else { + map.put(name, val); + } } else { map.put(stat+"("+field+")", val); } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java index 03c037a73ee..14f93b81496 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java @@ -37,27 +37,36 @@ public class MeanMetric extends Metric { private long count; public MeanMetric(String columnName){ - init("avg", columnName); + init("avg", columnName, false); + } + + public MeanMetric(String columnName, boolean outputLong){ + init("avg", columnName, outputLong); } public MeanMetric(StreamExpression expression, StreamFactory factory) throws IOException{ // grab all parameters out String functionName = expression.getFunctionName(); String columnName = factory.getValueOperand(expression, 0); - + String outputLong = factory.getValueOperand(expression, 1); + + // validate expression contains only what we want. if(null == columnName){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(columnName)", expression, functionName)); } - if(1 != expression.getParameters().size()){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression)); + + boolean ol = false; + if(outputLong != null) { + ol = Boolean.parseBoolean(outputLong); } - init(functionName, columnName); + init(functionName, columnName, ol); } - private void init(String functionName, String columnName){ + private void init(String functionName, String columnName, boolean outputLong){ this.columnName = columnName; + this.outputLong = outputLong; setFunctionName(functionName); setIdentifier(functionName, "(", columnName, ")"); } @@ -75,25 +84,29 @@ public class MeanMetric extends Metric { } public Metric newInstance() { - return new MeanMetric(columnName); + return new MeanMetric(columnName, outputLong); } public String[] getColumns() { return new String[]{columnName}; } - public Double getValue() { + public Number getValue() { double dcount = (double)count; if(longSum == 0) { return doubleSum/dcount; - } else { - return longSum/dcount; + double mean = longSum/dcount; + if(outputLong) { + return Math.round(mean); + } else { + return mean; + } } } @Override public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { - return new StreamExpression(getFunctionName()).withParameter(columnName); + return new StreamExpression(getFunctionName()).withParameter(columnName).withParameter(Boolean.toString(outputLong)); } } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java index 582b54ae441..87f7852c526 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java @@ -30,6 +30,7 @@ public abstract class Metric implements Expressible { private UUID metricNodeId = UUID.randomUUID(); private String functionName; private String identifier; + public boolean outputLong; // This is only used for SQL in facet mode. public String getFunctionName(){ return functionName; diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionToExpessionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionToExpessionTest.java index 4ddf4ce8dce..0a597b7ab80 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionToExpessionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionToExpessionTest.java @@ -155,7 +155,7 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase { assertTrue(expressionString.contains("sort=\"a_f asc, a_i asc\"")); assertTrue(expressionString.contains("min(a_i)")); assertTrue(expressionString.contains("max(a_i)")); - assertTrue(expressionString.contains("avg(a_i)")); + assertTrue(expressionString.contains("avg(a_i,false)")); assertTrue(expressionString.contains("count(*)")); assertTrue(expressionString.contains("sum(a_i)")); @@ -274,8 +274,8 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase { assertTrue(expressionString.contains("min(a_f)")); assertTrue(expressionString.contains("max(a_i)")); assertTrue(expressionString.contains("max(a_f)")); - assertTrue(expressionString.contains("avg(a_i)")); - assertTrue(expressionString.contains("avg(a_f)")); + assertTrue(expressionString.contains("avg(a_i,false)")); + assertTrue(expressionString.contains("avg(a_f,false)")); assertTrue(expressionString.contains("count(*)")); } @@ -427,7 +427,7 @@ public class StreamExpressionToExpessionTest extends LuceneTestCase { metric = new MeanMetric(StreamExpressionParser.parse("avg(foo)"), factory); expressionString = metric.toExpression(factory).toString(); - assertEquals("avg(foo)", expressionString); + assertEquals("avg(foo,false)", expressionString); } @Test