Cleanup implementation

This commit is contained in:
Kevin Risden 2016-05-04 14:56:23 -05:00
parent a414d24684
commit 5daf6c40d8
6 changed files with 106 additions and 57 deletions

View File

@ -30,7 +30,9 @@ import java.util.Properties;
* <p>It accepts connect strings that start with "jdbc:calcitesolr:".</p> * <p>It accepts connect strings that start with "jdbc:calcitesolr:".</p>
*/ */
public class CalciteSolrDriver extends Driver { public class CalciteSolrDriver extends Driver {
protected CalciteSolrDriver() { public final static String CONNECT_STRING_PREFIX = "jdbc:calcitesolr:";
private CalciteSolrDriver() {
super(); super();
} }
@ -38,11 +40,11 @@ public class CalciteSolrDriver extends Driver {
new CalciteSolrDriver().register(); new CalciteSolrDriver().register();
} }
@Override
protected String getConnectStringPrefix() { protected String getConnectStringPrefix() {
return "jdbc:calcitesolr:"; return CONNECT_STRING_PREFIX;
} }
@Override @Override
public Connection connect(String url, Properties info) throws SQLException { public Connection connect(String url, Properties info) throws SQLException {
Connection connection = super.connect(url, info); Connection connection = super.connect(url, info);

View File

@ -24,7 +24,7 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet; import org.apache.calcite.util.ImmutableBitSet;
import org.apache.solr.client.solrj.io.stream.metrics.*; import org.apache.calcite.util.Pair;
import java.util.*; import java.util.*;
@ -32,6 +32,15 @@ import java.util.*;
* Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr. * Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr.
*/ */
class SolrAggregate extends Aggregate implements SolrRel { class SolrAggregate extends Aggregate implements SolrRel {
private static final List<SqlAggFunction> SUPPORTED_AGGREGATIONS = Arrays.asList(
SqlStdOperatorTable.COUNT,
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.SUM0,
SqlStdOperatorTable.MIN,
SqlStdOperatorTable.MAX,
SqlStdOperatorTable.AVG
);
SolrAggregate( SolrAggregate(
RelOptCluster cluster, RelOptCluster cluster,
RelTraitSet traitSet, RelTraitSet traitSet,
@ -58,12 +67,12 @@ class SolrAggregate extends Aggregate implements SolrRel {
final List<String> inNames = SolrRules.solrFieldNames(getInput().getRowType()); final List<String> inNames = SolrRules.solrFieldNames(getInput().getRowType());
final List<String> outNames = SolrRules.solrFieldNames(getRowType()); final List<String> outNames = SolrRules.solrFieldNames(getRowType());
List<Metric> metrics = new ArrayList<>(); List<Pair<String, String>> metrics = new ArrayList<>();
Map<String, String> fieldMappings = new HashMap<>(); Map<String, String> fieldMappings = new HashMap<>();
for(AggregateCall aggCall : aggCalls) { for(AggregateCall aggCall : aggCalls) {
Metric metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList()); Pair<String, String> metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList());
metrics.add(metric); metrics.add(metric);
fieldMappings.put(aggCall.getName(), metric.getIdentifier()); fieldMappings.put(aggCall.getName(), metric.getKey().toLowerCase(Locale.ROOT) + "(" + metric.getValue() + ")");
} }
List<String> buckets = new ArrayList<>(); List<String> buckets = new ArrayList<>();
@ -78,22 +87,16 @@ class SolrAggregate extends Aggregate implements SolrRel {
implementor.addFieldMappings(fieldMappings); implementor.addFieldMappings(fieldMappings);
} }
private Metric toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) { private Pair<String, String> toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) {
switch (args.size()) { switch (args.size()) {
case 0: case 0:
if(aggregation.equals(SqlStdOperatorTable.COUNT)) { if (aggregation.equals(SqlStdOperatorTable.COUNT)) {
return new CountMetric(); return new Pair<>(aggregation.getName(), "*");
} }
case 1: case 1:
final String inName = inNames.get(args.get(0)); final String inName = inNames.get(args.get(0));
if (aggregation.equals(SqlStdOperatorTable.SUM) || aggregation.equals(SqlStdOperatorTable.SUM0)) { if(SUPPORTED_AGGREGATIONS.contains(aggregation)) {
return new SumMetric(inName); return new Pair<>(aggregation.getName(), inName);
} else if (aggregation.equals(SqlStdOperatorTable.MIN)) {
return new MinMetric(inName);
} else if (aggregation.equals(SqlStdOperatorTable.MAX)) {
return new MaxMetric(inName);
} else if (aggregation.equals(SqlStdOperatorTable.AVG)) {
return new MeanMetric(inName);
} }
default: default:
throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames); throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames);

View File

@ -19,7 +19,7 @@ package org.apache.solr.handler.sql;
import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelNode;
import org.apache.solr.client.solrj.io.stream.metrics.Metric; import org.apache.calcite.util.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -42,7 +42,7 @@ interface SolrRel extends RelNode {
String limitValue = null; String limitValue = null;
final List<String> order = new ArrayList<>(); final List<String> order = new ArrayList<>();
final List<String> buckets = new ArrayList<>(); final List<String> buckets = new ArrayList<>();
final List<Metric> metrics = new ArrayList<>(); final List<Pair<String, String>> metricPairs = new ArrayList<>();
RelOptTable table; RelOptTable table;
SolrTable solrTable; SolrTable solrTable;
@ -68,8 +68,8 @@ interface SolrRel extends RelNode {
this.buckets.addAll(buckets); this.buckets.addAll(buckets);
} }
void addMetrics(List<Metric> metrics) { void addMetrics(List<Pair<String, String>> metrics) {
this.metrics.addAll(metrics); this.metricPairs.addAll(metrics);
} }
void setLimit(String limit) { void setLimit(String limit) {

View File

@ -27,13 +27,14 @@ import org.apache.calcite.rel.type.RelProtoDataType;
import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.schema.TranslatableTable; import org.apache.calcite.schema.TranslatableTable;
import org.apache.calcite.schema.impl.AbstractTableQueryable; import org.apache.calcite.schema.impl.AbstractTableQueryable;
import org.apache.calcite.util.Pair;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream; import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.RollupStream; import org.apache.solr.client.solrj.io.stream.RollupStream;
import org.apache.solr.client.solrj.io.stream.StatsStream; import org.apache.solr.client.solrj.io.stream.StatsStream;
import org.apache.solr.client.solrj.io.stream.TupleStream; import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.metrics.Bucket; import org.apache.solr.client.solrj.io.stream.metrics.*;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.common.params.CommonParams; import org.apache.solr.common.params.CommonParams;
import org.apache.solr.update.VersionInfo;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
@ -43,7 +44,7 @@ import java.util.*;
*/ */
class SolrTable extends AbstractQueryableTable implements TranslatableTable { class SolrTable extends AbstractQueryableTable implements TranslatableTable {
private static final String DEFAULT_QUERY = "*:*"; private static final String DEFAULT_QUERY = "*:*";
private static final String DEFAULT_VERSION_FIELD = "_version_"; private static final String DEFAULT_VERSION_FIELD = VersionInfo.VERSION_FIELD;
private static final String DEFAULT_SCORE_FIELD = "score"; private static final String DEFAULT_SCORE_FIELD = "score";
private final String collection; private final String collection;
@ -81,7 +82,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
*/ */
private Enumerable<Object> query(final Properties properties, final List<String> fields, private Enumerable<Object> query(final Properties properties, final List<String> fields,
final String query, final List<String> order, final List<String> buckets, final String query, final List<String> order, final List<String> buckets,
final List<Metric> metrics, final String limit) { final List<Pair<String, String>> metricPairs, final String limit) {
// SolrParams should be a ModifiableParams instead of a map // SolrParams should be a ModifiableParams instead of a map
Map<String, String> solrParams = new HashMap<>(); Map<String, String> solrParams = new HashMap<>();
solrParams.put(CommonParams.OMIT_HEADER, "true"); solrParams.put(CommonParams.OMIT_HEADER, "true");
@ -96,10 +97,20 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
List<String> fieldsList = new ArrayList<>(fields); List<String> fieldsList = new ArrayList<>(fields);
List<String> orderList = new ArrayList<>(order); List<String> orderList = new ArrayList<>(order);
List<Metric> metrics = buildMetrics(metricPairs);
if (!metrics.isEmpty()) { if (!metrics.isEmpty()) {
for(String bucket : buckets) { for(String bucket : buckets) {
orderList.add(bucket + " desc"); orderList.add(bucket + " desc");
} }
for(Metric metric : metrics) {
for(String column : metric.getColumns()) {
if (!fieldsList.contains(column)) {
fieldsList.add(column);
}
}
}
} }
if (orderList.isEmpty()) { if (orderList.isEmpty()) {
@ -174,6 +185,32 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
}; };
} }
private List<Metric> buildMetrics(List<Pair<String, String>> metricPairs) {
List<Metric> metrics = new ArrayList<>(metricPairs.size());
for(Pair<String, String> metricPair : metricPairs) {
metrics.add(getMetric(metricPair));
}
return metrics;
}
private Metric getMetric(Pair<String, String> metricPair) {
switch (metricPair.getKey()) {
case "COUNT":
return new CountMetric(metricPair.getValue());
case "SUM":
case "$SUM0":
return new SumMetric(metricPair.getValue());
case "MIN":
return new MinMetric(metricPair.getValue());
case "MAX":
return new MaxMetric(metricPair.getValue());
case "AVG":
return new MeanMetric(metricPair.getValue());
default:
throw new IllegalArgumentException(metricPair.getKey());
}
}
public <T> Queryable<T> asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) { public <T> Queryable<T> asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) {
return new SolrQueryable<>(queryProvider, schema, this, tableName); return new SolrQueryable<>(queryProvider, schema, this, tableName);
} }
@ -190,7 +227,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
} }
public Enumerator<T> enumerator() { public Enumerator<T> enumerator() {
//noinspection unchecked @SuppressWarnings("unchecked")
final Enumerable<T> enumerable = (Enumerable<T>) getTable().query(getProperties()); final Enumerable<T> enumerable = (Enumerable<T>) getTable().query(getProperties());
return enumerable.enumerator(); return enumerable.enumerator();
} }
@ -209,8 +246,8 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
*/ */
@SuppressWarnings("UnusedDeclaration") @SuppressWarnings("UnusedDeclaration")
public Enumerable<Object> query(List<String> fields, String query, List<String> order, List<String> buckets, public Enumerable<Object> query(List<String> fields, String query, List<String> order, List<String> buckets,
List<Metric> metrics, String limit) { List<Pair<String, String>> metricPairs, String limit) {
return getTable().query(getProperties(), fields, query, order, buckets, metrics, limit); return getTable().query(getProperties(), fields, query, order, buckets, metricPairs, limit);
} }
} }
} }

View File

@ -23,14 +23,13 @@ import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.MethodCallExpression; import org.apache.calcite.linq4j.tree.MethodCallExpression;
import org.apache.calcite.plan.*; import org.apache.calcite.plan.*;
import org.apache.calcite.prepare.CalcitePrepareImpl;
import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterImpl; import org.apache.calcite.rel.convert.ConverterImpl;
import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.runtime.Hook; import org.apache.calcite.runtime.Hook;
import org.apache.calcite.util.BuiltInMethod; import org.apache.calcite.util.BuiltInMethod;
import org.apache.solr.client.solrj.io.stream.metrics.Metric; import org.apache.calcite.util.Pair;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -55,7 +54,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
} }
public Result implement(EnumerableRelImplementor implementor, Prefer pref) { public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
// Generates a call to "query" with the appropriate fields and filterQueries // Generates a call to "query" with the appropriate fields
final BlockBuilder list = new BlockBuilder(); final BlockBuilder list = new BlockBuilder();
final SolrRel.Implementor solrImplementor = new SolrRel.Implementor(); final SolrRel.Implementor solrImplementor = new SolrRel.Implementor();
solrImplementor.visitChild(0, getInput()); solrImplementor.visitChild(0, getInput());
@ -64,17 +63,14 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
final Expression table = list.append("table", solrImplementor.table.getExpression(SolrTable.SolrQueryable.class)); final Expression table = list.append("table", solrImplementor.table.getExpression(SolrTable.SolrQueryable.class));
final Expression fields = list.append("fields", final Expression fields = list.append("fields",
constantArrayList(generateFields(SolrRules.solrFieldNames(rowType), solrImplementor.fieldMappings), String.class)); constantArrayList(generateFields(SolrRules.solrFieldNames(rowType), solrImplementor.fieldMappings), String.class));
final Expression filterQueries = list.append("query", Expressions.constant(solrImplementor.query, String.class)); final Expression query = list.append("query", Expressions.constant(solrImplementor.query, String.class));
final Expression order = list.append("order", constantArrayList(solrImplementor.order, String.class)); final Expression order = list.append("order", constantArrayList(solrImplementor.order, String.class));
final Expression buckets = list.append("buckets", constantArrayList(solrImplementor.buckets, String.class)); final Expression buckets = list.append("buckets", constantArrayList(solrImplementor.buckets, String.class));
final Expression metrics = list.append("metrics", constantArrayList(solrImplementor.metrics, Metric.class)); final Expression metricPairs = list.append("metricPairs", constantArrayList(solrImplementor.metricPairs, Pair.class));
final Expression limit = list.append("limit", Expressions.constant(solrImplementor.limitValue)); final Expression limit = list.append("limit", Expressions.constant(solrImplementor.limitValue));
Expression enumerable = list.append("enumerable", Expressions.call(table, SolrMethod.SOLR_QUERYABLE_QUERY.method, Expression enumerable = list.append("enumerable", Expressions.call(table, SolrMethod.SOLR_QUERYABLE_QUERY.method,
fields, filterQueries, order, buckets, metrics, limit)); fields, query, order, buckets, metricPairs, limit));
if (CalcitePrepareImpl.DEBUG) { Hook.QUERY_PLAN.run(query);
System.out.println("Solr: " + filterQueries);
}
Hook.QUERY_PLAN.run(filterQueries);
list.add(Expressions.return_(null, enumerable)); list.add(Expressions.return_(null, enumerable));
return implementor.result(physType, list.toBlock()); return implementor.result(physType, list.toBlock());
} }

View File

@ -24,40 +24,50 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class CountMetric extends Metric { public class CountMetric extends Metric {
private String columnName;
private long count; private long count;
public CountMetric(){ public CountMetric() {
init("count"); this("*");
}
public CountMetric(String columnName) {
init("count", columnName);
} }
public CountMetric(StreamExpression expression, StreamFactory factory) throws IOException{ public CountMetric(StreamExpression expression, StreamFactory factory) throws IOException{
// grab all parameters out // grab all parameters out
String functionName = expression.getFunctionName(); String functionName = expression.getFunctionName();
String columnName = factory.getValueOperand(expression, 0); String columnName = factory.getValueOperand(expression, 0);
// validate expression contains only what we want.
if(!"*".equals(columnName)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(*)", expression, functionName));
}
if(1 != expression.getParameters().size()){ if(1 != expression.getParameters().size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression)); throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression));
} }
init(functionName); init(functionName, columnName);
} }
public String[] getColumns() { public String[] getColumns() {
return new String[0]; if(isAllColumns()) {
return new String[0];
}
return new String[]{columnName};
} }
private void init(String functionName){ private void init(String functionName, String columnName){
this.columnName = columnName;
setFunctionName(functionName); setFunctionName(functionName);
setIdentifier(functionName, "(*)"); setIdentifier(functionName, "(", columnName, ")");
}
private boolean isAllColumns() {
return "*".equals(this.columnName);
} }
public void update(Tuple tuple) { public void update(Tuple tuple) {
++count; if(isAllColumns() || tuple.get(columnName) != null) {
++count;
}
} }
public Long getValue() { public Long getValue() {
@ -65,10 +75,11 @@ public class CountMetric extends Metric {
} }
public Metric newInstance() { public Metric newInstance() {
return new CountMetric(); return new CountMetric(columnName);
} }
@Override @Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
return new StreamExpression(getFunctionName()).withParameter("*"); return new StreamExpression(getFunctionName()).withParameter(columnName);
} }
} }