Cleanup implementation

This commit is contained in:
Kevin Risden 2016-05-04 14:56:23 -05:00
parent a414d24684
commit 5daf6c40d8
6 changed files with 106 additions and 57 deletions

View File

@ -30,7 +30,9 @@ import java.util.Properties;
* <p>It accepts connect strings that start with "jdbc:calcitesolr:".</p>
*/
public class CalciteSolrDriver extends Driver {
protected CalciteSolrDriver() {
public final static String CONNECT_STRING_PREFIX = "jdbc:calcitesolr:";
private CalciteSolrDriver() {
super();
}
@ -38,11 +40,11 @@ public class CalciteSolrDriver extends Driver {
new CalciteSolrDriver().register();
}
@Override
protected String getConnectStringPrefix() {
return "jdbc:calcitesolr:";
return CONNECT_STRING_PREFIX;
}
@Override
public Connection connect(String url, Properties info) throws SQLException {
Connection connection = super.connect(url, info);

View File

@ -24,7 +24,7 @@ import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.solr.client.solrj.io.stream.metrics.*;
import org.apache.calcite.util.Pair;
import java.util.*;
@ -32,6 +32,15 @@ import java.util.*;
* Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr.
*/
class SolrAggregate extends Aggregate implements SolrRel {
private static final List<SqlAggFunction> SUPPORTED_AGGREGATIONS = Arrays.asList(
SqlStdOperatorTable.COUNT,
SqlStdOperatorTable.SUM,
SqlStdOperatorTable.SUM0,
SqlStdOperatorTable.MIN,
SqlStdOperatorTable.MAX,
SqlStdOperatorTable.AVG
);
SolrAggregate(
RelOptCluster cluster,
RelTraitSet traitSet,
@ -58,12 +67,12 @@ class SolrAggregate extends Aggregate implements SolrRel {
final List<String> inNames = SolrRules.solrFieldNames(getInput().getRowType());
final List<String> outNames = SolrRules.solrFieldNames(getRowType());
List<Metric> metrics = new ArrayList<>();
List<Pair<String, String>> metrics = new ArrayList<>();
Map<String, String> fieldMappings = new HashMap<>();
for(AggregateCall aggCall : aggCalls) {
Metric metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList());
Pair<String, String> metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList());
metrics.add(metric);
fieldMappings.put(aggCall.getName(), metric.getIdentifier());
fieldMappings.put(aggCall.getName(), metric.getKey().toLowerCase(Locale.ROOT) + "(" + metric.getValue() + ")");
}
List<String> buckets = new ArrayList<>();
@ -78,22 +87,16 @@ class SolrAggregate extends Aggregate implements SolrRel {
implementor.addFieldMappings(fieldMappings);
}
private Metric toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) {
private Pair<String, String> toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) {
switch (args.size()) {
case 0:
if (aggregation.equals(SqlStdOperatorTable.COUNT)) {
return new CountMetric();
return new Pair<>(aggregation.getName(), "*");
}
case 1:
final String inName = inNames.get(args.get(0));
if (aggregation.equals(SqlStdOperatorTable.SUM) || aggregation.equals(SqlStdOperatorTable.SUM0)) {
return new SumMetric(inName);
} else if (aggregation.equals(SqlStdOperatorTable.MIN)) {
return new MinMetric(inName);
} else if (aggregation.equals(SqlStdOperatorTable.MAX)) {
return new MaxMetric(inName);
} else if (aggregation.equals(SqlStdOperatorTable.AVG)) {
return new MeanMetric(inName);
if(SUPPORTED_AGGREGATIONS.contains(aggregation)) {
return new Pair<>(aggregation.getName(), inName);
}
default:
throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames);

View File

@ -19,7 +19,7 @@ package org.apache.solr.handler.sql;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.rel.RelNode;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.calcite.util.Pair;
import java.util.ArrayList;
import java.util.HashMap;
@ -42,7 +42,7 @@ interface SolrRel extends RelNode {
String limitValue = null;
final List<String> order = new ArrayList<>();
final List<String> buckets = new ArrayList<>();
final List<Metric> metrics = new ArrayList<>();
final List<Pair<String, String>> metricPairs = new ArrayList<>();
RelOptTable table;
SolrTable solrTable;
@ -68,8 +68,8 @@ interface SolrRel extends RelNode {
this.buckets.addAll(buckets);
}
void addMetrics(List<Metric> metrics) {
this.metrics.addAll(metrics);
void addMetrics(List<Pair<String, String>> metrics) {
this.metricPairs.addAll(metrics);
}
void setLimit(String limit) {

View File

@ -27,13 +27,14 @@ import org.apache.calcite.rel.type.RelProtoDataType;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.schema.TranslatableTable;
import org.apache.calcite.schema.impl.AbstractTableQueryable;
import org.apache.calcite.util.Pair;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.RollupStream;
import org.apache.solr.client.solrj.io.stream.StatsStream;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.metrics.Bucket;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.client.solrj.io.stream.metrics.*;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.update.VersionInfo;
import java.io.IOException;
import java.util.*;
@ -43,7 +44,7 @@ import java.util.*;
*/
class SolrTable extends AbstractQueryableTable implements TranslatableTable {
private static final String DEFAULT_QUERY = "*:*";
private static final String DEFAULT_VERSION_FIELD = "_version_";
private static final String DEFAULT_VERSION_FIELD = VersionInfo.VERSION_FIELD;
private static final String DEFAULT_SCORE_FIELD = "score";
private final String collection;
@ -81,7 +82,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
*/
private Enumerable<Object> query(final Properties properties, final List<String> fields,
final String query, final List<String> order, final List<String> buckets,
final List<Metric> metrics, final String limit) {
final List<Pair<String, String>> metricPairs, final String limit) {
// SolrParams should be a ModifiableParams instead of a map
Map<String, String> solrParams = new HashMap<>();
solrParams.put(CommonParams.OMIT_HEADER, "true");
@ -96,10 +97,20 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
List<String> fieldsList = new ArrayList<>(fields);
List<String> orderList = new ArrayList<>(order);
List<Metric> metrics = buildMetrics(metricPairs);
if (!metrics.isEmpty()) {
for(String bucket : buckets) {
orderList.add(bucket + " desc");
}
for(Metric metric : metrics) {
for(String column : metric.getColumns()) {
if (!fieldsList.contains(column)) {
fieldsList.add(column);
}
}
}
}
if (orderList.isEmpty()) {
@ -174,6 +185,32 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
};
}
private List<Metric> buildMetrics(List<Pair<String, String>> metricPairs) {
List<Metric> metrics = new ArrayList<>(metricPairs.size());
for(Pair<String, String> metricPair : metricPairs) {
metrics.add(getMetric(metricPair));
}
return metrics;
}
private Metric getMetric(Pair<String, String> metricPair) {
switch (metricPair.getKey()) {
case "COUNT":
return new CountMetric(metricPair.getValue());
case "SUM":
case "$SUM0":
return new SumMetric(metricPair.getValue());
case "MIN":
return new MinMetric(metricPair.getValue());
case "MAX":
return new MaxMetric(metricPair.getValue());
case "AVG":
return new MeanMetric(metricPair.getValue());
default:
throw new IllegalArgumentException(metricPair.getKey());
}
}
public <T> Queryable<T> asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) {
return new SolrQueryable<>(queryProvider, schema, this, tableName);
}
@ -190,7 +227,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
}
public Enumerator<T> enumerator() {
//noinspection unchecked
@SuppressWarnings("unchecked")
final Enumerable<T> enumerable = (Enumerable<T>) getTable().query(getProperties());
return enumerable.enumerator();
}
@ -209,8 +246,8 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
*/
@SuppressWarnings("UnusedDeclaration")
public Enumerable<Object> query(List<String> fields, String query, List<String> order, List<String> buckets,
List<Metric> metrics, String limit) {
return getTable().query(getProperties(), fields, query, order, buckets, metrics, limit);
List<Pair<String, String>> metricPairs, String limit) {
return getTable().query(getProperties(), fields, query, order, buckets, metricPairs, limit);
}
}
}

View File

@ -23,14 +23,13 @@ import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.MethodCallExpression;
import org.apache.calcite.plan.*;
import org.apache.calcite.prepare.CalcitePrepareImpl;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterImpl;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.runtime.Hook;
import org.apache.calcite.util.BuiltInMethod;
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.calcite.util.Pair;
import java.util.ArrayList;
import java.util.List;
@ -55,7 +54,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
}
public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
// Generates a call to "query" with the appropriate fields and filterQueries
// Generates a call to "query" with the appropriate fields
final BlockBuilder list = new BlockBuilder();
final SolrRel.Implementor solrImplementor = new SolrRel.Implementor();
solrImplementor.visitChild(0, getInput());
@ -64,17 +63,14 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
final Expression table = list.append("table", solrImplementor.table.getExpression(SolrTable.SolrQueryable.class));
final Expression fields = list.append("fields",
constantArrayList(generateFields(SolrRules.solrFieldNames(rowType), solrImplementor.fieldMappings), String.class));
final Expression filterQueries = list.append("query", Expressions.constant(solrImplementor.query, String.class));
final Expression query = list.append("query", Expressions.constant(solrImplementor.query, String.class));
final Expression order = list.append("order", constantArrayList(solrImplementor.order, String.class));
final Expression buckets = list.append("buckets", constantArrayList(solrImplementor.buckets, String.class));
final Expression metrics = list.append("metrics", constantArrayList(solrImplementor.metrics, Metric.class));
final Expression metricPairs = list.append("metricPairs", constantArrayList(solrImplementor.metricPairs, Pair.class));
final Expression limit = list.append("limit", Expressions.constant(solrImplementor.limitValue));
Expression enumerable = list.append("enumerable", Expressions.call(table, SolrMethod.SOLR_QUERYABLE_QUERY.method,
fields, filterQueries, order, buckets, metrics, limit));
if (CalcitePrepareImpl.DEBUG) {
System.out.println("Solr: " + filterQueries);
}
Hook.QUERY_PLAN.run(filterQueries);
fields, query, order, buckets, metricPairs, limit));
Hook.QUERY_PLAN.run(query);
list.add(Expressions.return_(null, enumerable));
return implementor.result(physType, list.toBlock());
}

View File

@ -24,10 +24,15 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public class CountMetric extends Metric {
private String columnName;
private long count;
public CountMetric() {
init("count");
this("*");
}
public CountMetric(String columnName) {
init("count", columnName);
}
public CountMetric(StreamExpression expression, StreamFactory factory) throws IOException{
@ -35,40 +40,46 @@ public class CountMetric extends Metric {
String functionName = expression.getFunctionName();
String columnName = factory.getValueOperand(expression, 0);
// validate expression contains only what we want.
if(!"*".equals(columnName)){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(*)", expression, functionName));
}
if(1 != expression.getParameters().size()){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression));
}
init(functionName);
init(functionName, columnName);
}
public String[] getColumns() {
if(isAllColumns()) {
return new String[0];
}
return new String[]{columnName};
}
private void init(String functionName){
private void init(String functionName, String columnName){
this.columnName = columnName;
setFunctionName(functionName);
setIdentifier(functionName, "(*)");
setIdentifier(functionName, "(", columnName, ")");
}
private boolean isAllColumns() {
return "*".equals(this.columnName);
}
public void update(Tuple tuple) {
if(isAllColumns() || tuple.get(columnName) != null) {
++count;
}
}
public Long getValue() {
return count;
}
public Metric newInstance() {
return new CountMetric();
return new CountMetric(columnName);
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
return new StreamExpression(getFunctionName()).withParameter("*");
return new StreamExpression(getFunctionName()).withParameter(columnName);
}
}