diff --git a/solr/core/src/java/org/apache/solr/handler/sql/CalciteSolrDriver.java b/solr/core/src/java/org/apache/solr/handler/sql/CalciteSolrDriver.java index 35c9f9d0db8..3dd30cc0e9c 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/CalciteSolrDriver.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/CalciteSolrDriver.java @@ -30,7 +30,9 @@ import java.util.Properties; *

It accepts connect strings that start with "jdbc:calcitesolr:".

*/ public class CalciteSolrDriver extends Driver { - protected CalciteSolrDriver() { + public final static String CONNECT_STRING_PREFIX = "jdbc:calcitesolr:"; + + private CalciteSolrDriver() { super(); } @@ -38,11 +40,11 @@ public class CalciteSolrDriver extends Driver { new CalciteSolrDriver().register(); } + @Override protected String getConnectStringPrefix() { - return "jdbc:calcitesolr:"; + return CONNECT_STRING_PREFIX; } - @Override public Connection connect(String url, Properties info) throws SQLException { Connection connection = super.connect(url, info); diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java index 582b9c4e167..aed2e339801 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrAggregate.java @@ -24,7 +24,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; -import org.apache.solr.client.solrj.io.stream.metrics.*; +import org.apache.calcite.util.Pair; import java.util.*; @@ -32,6 +32,15 @@ import java.util.*; * Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr. */ class SolrAggregate extends Aggregate implements SolrRel { + private static final List SUPPORTED_AGGREGATIONS = Arrays.asList( + SqlStdOperatorTable.COUNT, + SqlStdOperatorTable.SUM, + SqlStdOperatorTable.SUM0, + SqlStdOperatorTable.MIN, + SqlStdOperatorTable.MAX, + SqlStdOperatorTable.AVG + ); + SolrAggregate( RelOptCluster cluster, RelTraitSet traitSet, @@ -58,12 +67,12 @@ class SolrAggregate extends Aggregate implements SolrRel { final List inNames = SolrRules.solrFieldNames(getInput().getRowType()); final List outNames = SolrRules.solrFieldNames(getRowType()); - List metrics = new ArrayList<>(); + List> metrics = new ArrayList<>(); Map fieldMappings = new HashMap<>(); for(AggregateCall aggCall : aggCalls) { - Metric metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList()); + Pair metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList()); metrics.add(metric); - fieldMappings.put(aggCall.getName(), metric.getIdentifier()); + fieldMappings.put(aggCall.getName(), metric.getKey().toLowerCase(Locale.ROOT) + "(" + metric.getValue() + ")"); } List buckets = new ArrayList<>(); @@ -78,22 +87,16 @@ class SolrAggregate extends Aggregate implements SolrRel { implementor.addFieldMappings(fieldMappings); } - private Metric toSolrMetric(SqlAggFunction aggregation, List inNames, List args) { + private Pair toSolrMetric(SqlAggFunction aggregation, List inNames, List args) { switch (args.size()) { case 0: - if(aggregation.equals(SqlStdOperatorTable.COUNT)) { - return new CountMetric(); + if (aggregation.equals(SqlStdOperatorTable.COUNT)) { + return new Pair<>(aggregation.getName(), "*"); } case 1: final String inName = inNames.get(args.get(0)); - if (aggregation.equals(SqlStdOperatorTable.SUM) || aggregation.equals(SqlStdOperatorTable.SUM0)) { - return new SumMetric(inName); - } else if (aggregation.equals(SqlStdOperatorTable.MIN)) { - return new MinMetric(inName); - } else if (aggregation.equals(SqlStdOperatorTable.MAX)) { - return new MaxMetric(inName); - } else if (aggregation.equals(SqlStdOperatorTable.AVG)) { - return new MeanMetric(inName); + if(SUPPORTED_AGGREGATIONS.contains(aggregation)) { + return new Pair<>(aggregation.getName(), inName); } default: throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames); diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java index 9a11488fc79..5b67627db73 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrRel.java @@ -19,7 +19,7 @@ package org.apache.solr.handler.sql; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.rel.RelNode; -import org.apache.solr.client.solrj.io.stream.metrics.Metric; +import org.apache.calcite.util.Pair; import java.util.ArrayList; import java.util.HashMap; @@ -42,7 +42,7 @@ interface SolrRel extends RelNode { String limitValue = null; final List order = new ArrayList<>(); final List buckets = new ArrayList<>(); - final List metrics = new ArrayList<>(); + final List> metricPairs = new ArrayList<>(); RelOptTable table; SolrTable solrTable; @@ -68,8 +68,8 @@ interface SolrRel extends RelNode { this.buckets.addAll(buckets); } - void addMetrics(List metrics) { - this.metrics.addAll(metrics); + void addMetrics(List> metrics) { + this.metricPairs.addAll(metrics); } void setLimit(String limit) { diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java index 167dc95a2e5..753e9f835ad 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrTable.java @@ -27,13 +27,14 @@ import org.apache.calcite.rel.type.RelProtoDataType; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.schema.impl.AbstractTableQueryable; +import org.apache.calcite.util.Pair; import org.apache.solr.client.solrj.io.stream.CloudSolrStream; import org.apache.solr.client.solrj.io.stream.RollupStream; import org.apache.solr.client.solrj.io.stream.StatsStream; import org.apache.solr.client.solrj.io.stream.TupleStream; -import org.apache.solr.client.solrj.io.stream.metrics.Bucket; -import org.apache.solr.client.solrj.io.stream.metrics.Metric; +import org.apache.solr.client.solrj.io.stream.metrics.*; import org.apache.solr.common.params.CommonParams; +import org.apache.solr.update.VersionInfo; import java.io.IOException; import java.util.*; @@ -43,7 +44,7 @@ import java.util.*; */ class SolrTable extends AbstractQueryableTable implements TranslatableTable { private static final String DEFAULT_QUERY = "*:*"; - private static final String DEFAULT_VERSION_FIELD = "_version_"; + private static final String DEFAULT_VERSION_FIELD = VersionInfo.VERSION_FIELD; private static final String DEFAULT_SCORE_FIELD = "score"; private final String collection; @@ -81,7 +82,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { */ private Enumerable query(final Properties properties, final List fields, final String query, final List order, final List buckets, - final List metrics, final String limit) { + final List> metricPairs, final String limit) { // SolrParams should be a ModifiableParams instead of a map Map solrParams = new HashMap<>(); solrParams.put(CommonParams.OMIT_HEADER, "true"); @@ -96,10 +97,20 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { List fieldsList = new ArrayList<>(fields); List orderList = new ArrayList<>(order); + List metrics = buildMetrics(metricPairs); + if (!metrics.isEmpty()) { for(String bucket : buckets) { orderList.add(bucket + " desc"); } + + for(Metric metric : metrics) { + for(String column : metric.getColumns()) { + if (!fieldsList.contains(column)) { + fieldsList.add(column); + } + } + } } if (orderList.isEmpty()) { @@ -174,6 +185,32 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { }; } + private List buildMetrics(List> metricPairs) { + List metrics = new ArrayList<>(metricPairs.size()); + for(Pair metricPair : metricPairs) { + metrics.add(getMetric(metricPair)); + } + return metrics; + } + + private Metric getMetric(Pair metricPair) { + switch (metricPair.getKey()) { + case "COUNT": + return new CountMetric(metricPair.getValue()); + case "SUM": + case "$SUM0": + return new SumMetric(metricPair.getValue()); + case "MIN": + return new MinMetric(metricPair.getValue()); + case "MAX": + return new MaxMetric(metricPair.getValue()); + case "AVG": + return new MeanMetric(metricPair.getValue()); + default: + throw new IllegalArgumentException(metricPair.getKey()); + } + } + public Queryable asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) { return new SolrQueryable<>(queryProvider, schema, this, tableName); } @@ -190,7 +227,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { } public Enumerator enumerator() { - //noinspection unchecked + @SuppressWarnings("unchecked") final Enumerable enumerable = (Enumerable) getTable().query(getProperties()); return enumerable.enumerator(); } @@ -209,8 +246,8 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable { */ @SuppressWarnings("UnusedDeclaration") public Enumerable query(List fields, String query, List order, List buckets, - List metrics, String limit) { - return getTable().query(getProperties(), fields, query, order, buckets, metrics, limit); + List> metricPairs, String limit) { + return getTable().query(getProperties(), fields, query, order, buckets, metricPairs, limit); } } } diff --git a/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java b/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java index 4cdc92b6c5a..e6aea05dbb7 100644 --- a/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java +++ b/solr/core/src/java/org/apache/solr/handler/sql/SolrToEnumerableConverter.java @@ -23,14 +23,13 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.linq4j.tree.MethodCallExpression; import org.apache.calcite.plan.*; -import org.apache.calcite.prepare.CalcitePrepareImpl; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.convert.ConverterImpl; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.runtime.Hook; import org.apache.calcite.util.BuiltInMethod; -import org.apache.solr.client.solrj.io.stream.metrics.Metric; +import org.apache.calcite.util.Pair; import java.util.ArrayList; import java.util.List; @@ -55,7 +54,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel { } public Result implement(EnumerableRelImplementor implementor, Prefer pref) { - // Generates a call to "query" with the appropriate fields and filterQueries + // Generates a call to "query" with the appropriate fields final BlockBuilder list = new BlockBuilder(); final SolrRel.Implementor solrImplementor = new SolrRel.Implementor(); solrImplementor.visitChild(0, getInput()); @@ -64,17 +63,14 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel { final Expression table = list.append("table", solrImplementor.table.getExpression(SolrTable.SolrQueryable.class)); final Expression fields = list.append("fields", constantArrayList(generateFields(SolrRules.solrFieldNames(rowType), solrImplementor.fieldMappings), String.class)); - final Expression filterQueries = list.append("query", Expressions.constant(solrImplementor.query, String.class)); + final Expression query = list.append("query", Expressions.constant(solrImplementor.query, String.class)); final Expression order = list.append("order", constantArrayList(solrImplementor.order, String.class)); final Expression buckets = list.append("buckets", constantArrayList(solrImplementor.buckets, String.class)); - final Expression metrics = list.append("metrics", constantArrayList(solrImplementor.metrics, Metric.class)); + final Expression metricPairs = list.append("metricPairs", constantArrayList(solrImplementor.metricPairs, Pair.class)); final Expression limit = list.append("limit", Expressions.constant(solrImplementor.limitValue)); Expression enumerable = list.append("enumerable", Expressions.call(table, SolrMethod.SOLR_QUERYABLE_QUERY.method, - fields, filterQueries, order, buckets, metrics, limit)); - if (CalcitePrepareImpl.DEBUG) { - System.out.println("Solr: " + filterQueries); - } - Hook.QUERY_PLAN.run(filterQueries); + fields, query, order, buckets, metricPairs, limit)); + Hook.QUERY_PLAN.run(query); list.add(Expressions.return_(null, enumerable)); return implementor.result(physType, list.toBlock()); } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java index 0e8cbb057b8..61b83398e30 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java @@ -24,40 +24,50 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; public class CountMetric extends Metric { + private String columnName; private long count; - - public CountMetric(){ - init("count"); + + public CountMetric() { + this("*"); + } + + public CountMetric(String columnName) { + init("count", columnName); } public CountMetric(StreamExpression expression, StreamFactory factory) throws IOException{ // grab all parameters out String functionName = expression.getFunctionName(); String columnName = factory.getValueOperand(expression, 0); - - // validate expression contains only what we want. - if(!"*".equals(columnName)){ - throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(*)", expression, functionName)); - } + if(1 != expression.getParameters().size()){ throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression)); } - - init(functionName); - + + init(functionName, columnName); } public String[] getColumns() { - return new String[0]; + if(isAllColumns()) { + return new String[0]; + } + return new String[]{columnName}; } - - private void init(String functionName){ + + private void init(String functionName, String columnName){ + this.columnName = columnName; setFunctionName(functionName); - setIdentifier(functionName, "(*)"); + setIdentifier(functionName, "(", columnName, ")"); + } + + private boolean isAllColumns() { + return "*".equals(this.columnName); } public void update(Tuple tuple) { - ++count; + if(isAllColumns() || tuple.get(columnName) != null) { + ++count; + } } public Long getValue() { @@ -65,10 +75,11 @@ public class CountMetric extends Metric { } public Metric newInstance() { - return new CountMetric(); + return new CountMetric(columnName); } + @Override public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { - return new StreamExpression(getFunctionName()).withParameter("*"); + return new StreamExpression(getFunctionName()).withParameter(columnName); } } \ No newline at end of file