mirror of https://github.com/apache/lucene.git
Cleanup implementation
This commit is contained in:
parent
a414d24684
commit
5daf6c40d8
|
@ -30,7 +30,9 @@ import java.util.Properties;
|
|||
* <p>It accepts connect strings that start with "jdbc:calcitesolr:".</p>
|
||||
*/
|
||||
public class CalciteSolrDriver extends Driver {
|
||||
protected CalciteSolrDriver() {
|
||||
public final static String CONNECT_STRING_PREFIX = "jdbc:calcitesolr:";
|
||||
|
||||
private CalciteSolrDriver() {
|
||||
super();
|
||||
}
|
||||
|
||||
|
@ -38,11 +40,11 @@ public class CalciteSolrDriver extends Driver {
|
|||
new CalciteSolrDriver().register();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String getConnectStringPrefix() {
|
||||
return "jdbc:calcitesolr:";
|
||||
return CONNECT_STRING_PREFIX;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Connection connect(String url, Properties info) throws SQLException {
|
||||
Connection connection = super.connect(url, info);
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.calcite.rel.core.AggregateCall;
|
|||
import org.apache.calcite.sql.SqlAggFunction;
|
||||
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
|
||||
import org.apache.calcite.util.ImmutableBitSet;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.*;
|
||||
import org.apache.calcite.util.Pair;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
@ -32,6 +32,15 @@ import java.util.*;
|
|||
* Implementation of {@link org.apache.calcite.rel.core.Aggregate} relational expression in Solr.
|
||||
*/
|
||||
class SolrAggregate extends Aggregate implements SolrRel {
|
||||
private static final List<SqlAggFunction> SUPPORTED_AGGREGATIONS = Arrays.asList(
|
||||
SqlStdOperatorTable.COUNT,
|
||||
SqlStdOperatorTable.SUM,
|
||||
SqlStdOperatorTable.SUM0,
|
||||
SqlStdOperatorTable.MIN,
|
||||
SqlStdOperatorTable.MAX,
|
||||
SqlStdOperatorTable.AVG
|
||||
);
|
||||
|
||||
SolrAggregate(
|
||||
RelOptCluster cluster,
|
||||
RelTraitSet traitSet,
|
||||
|
@ -58,12 +67,12 @@ class SolrAggregate extends Aggregate implements SolrRel {
|
|||
final List<String> inNames = SolrRules.solrFieldNames(getInput().getRowType());
|
||||
final List<String> outNames = SolrRules.solrFieldNames(getRowType());
|
||||
|
||||
List<Metric> metrics = new ArrayList<>();
|
||||
List<Pair<String, String>> metrics = new ArrayList<>();
|
||||
Map<String, String> fieldMappings = new HashMap<>();
|
||||
for(AggregateCall aggCall : aggCalls) {
|
||||
Metric metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList());
|
||||
Pair<String, String> metric = toSolrMetric(aggCall.getAggregation(), inNames, aggCall.getArgList());
|
||||
metrics.add(metric);
|
||||
fieldMappings.put(aggCall.getName(), metric.getIdentifier());
|
||||
fieldMappings.put(aggCall.getName(), metric.getKey().toLowerCase(Locale.ROOT) + "(" + metric.getValue() + ")");
|
||||
}
|
||||
|
||||
List<String> buckets = new ArrayList<>();
|
||||
|
@ -78,22 +87,16 @@ class SolrAggregate extends Aggregate implements SolrRel {
|
|||
implementor.addFieldMappings(fieldMappings);
|
||||
}
|
||||
|
||||
private Metric toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) {
|
||||
private Pair<String, String> toSolrMetric(SqlAggFunction aggregation, List<String> inNames, List<Integer> args) {
|
||||
switch (args.size()) {
|
||||
case 0:
|
||||
if(aggregation.equals(SqlStdOperatorTable.COUNT)) {
|
||||
return new CountMetric();
|
||||
if (aggregation.equals(SqlStdOperatorTable.COUNT)) {
|
||||
return new Pair<>(aggregation.getName(), "*");
|
||||
}
|
||||
case 1:
|
||||
final String inName = inNames.get(args.get(0));
|
||||
if (aggregation.equals(SqlStdOperatorTable.SUM) || aggregation.equals(SqlStdOperatorTable.SUM0)) {
|
||||
return new SumMetric(inName);
|
||||
} else if (aggregation.equals(SqlStdOperatorTable.MIN)) {
|
||||
return new MinMetric(inName);
|
||||
} else if (aggregation.equals(SqlStdOperatorTable.MAX)) {
|
||||
return new MaxMetric(inName);
|
||||
} else if (aggregation.equals(SqlStdOperatorTable.AVG)) {
|
||||
return new MeanMetric(inName);
|
||||
if(SUPPORTED_AGGREGATIONS.contains(aggregation)) {
|
||||
return new Pair<>(aggregation.getName(), inName);
|
||||
}
|
||||
default:
|
||||
throw new AssertionError("Invalid aggregation " + aggregation + " with args " + args + " with names" + inNames);
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.solr.handler.sql;
|
|||
import org.apache.calcite.plan.Convention;
|
||||
import org.apache.calcite.plan.RelOptTable;
|
||||
import org.apache.calcite.rel.RelNode;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
|
||||
import org.apache.calcite.util.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
|
@ -42,7 +42,7 @@ interface SolrRel extends RelNode {
|
|||
String limitValue = null;
|
||||
final List<String> order = new ArrayList<>();
|
||||
final List<String> buckets = new ArrayList<>();
|
||||
final List<Metric> metrics = new ArrayList<>();
|
||||
final List<Pair<String, String>> metricPairs = new ArrayList<>();
|
||||
|
||||
RelOptTable table;
|
||||
SolrTable solrTable;
|
||||
|
@ -68,8 +68,8 @@ interface SolrRel extends RelNode {
|
|||
this.buckets.addAll(buckets);
|
||||
}
|
||||
|
||||
void addMetrics(List<Metric> metrics) {
|
||||
this.metrics.addAll(metrics);
|
||||
void addMetrics(List<Pair<String, String>> metrics) {
|
||||
this.metricPairs.addAll(metrics);
|
||||
}
|
||||
|
||||
void setLimit(String limit) {
|
||||
|
|
|
@ -27,13 +27,14 @@ import org.apache.calcite.rel.type.RelProtoDataType;
|
|||
import org.apache.calcite.schema.SchemaPlus;
|
||||
import org.apache.calcite.schema.TranslatableTable;
|
||||
import org.apache.calcite.schema.impl.AbstractTableQueryable;
|
||||
import org.apache.calcite.util.Pair;
|
||||
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
|
||||
import org.apache.solr.client.solrj.io.stream.RollupStream;
|
||||
import org.apache.solr.client.solrj.io.stream.StatsStream;
|
||||
import org.apache.solr.client.solrj.io.stream.TupleStream;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.Bucket;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.*;
|
||||
import org.apache.solr.common.params.CommonParams;
|
||||
import org.apache.solr.update.VersionInfo;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.*;
|
||||
|
@ -43,7 +44,7 @@ import java.util.*;
|
|||
*/
|
||||
class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
||||
private static final String DEFAULT_QUERY = "*:*";
|
||||
private static final String DEFAULT_VERSION_FIELD = "_version_";
|
||||
private static final String DEFAULT_VERSION_FIELD = VersionInfo.VERSION_FIELD;
|
||||
private static final String DEFAULT_SCORE_FIELD = "score";
|
||||
|
||||
private final String collection;
|
||||
|
@ -81,7 +82,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
|||
*/
|
||||
private Enumerable<Object> query(final Properties properties, final List<String> fields,
|
||||
final String query, final List<String> order, final List<String> buckets,
|
||||
final List<Metric> metrics, final String limit) {
|
||||
final List<Pair<String, String>> metricPairs, final String limit) {
|
||||
// SolrParams should be a ModifiableParams instead of a map
|
||||
Map<String, String> solrParams = new HashMap<>();
|
||||
solrParams.put(CommonParams.OMIT_HEADER, "true");
|
||||
|
@ -96,10 +97,20 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
|||
List<String> fieldsList = new ArrayList<>(fields);
|
||||
List<String> orderList = new ArrayList<>(order);
|
||||
|
||||
List<Metric> metrics = buildMetrics(metricPairs);
|
||||
|
||||
if (!metrics.isEmpty()) {
|
||||
for(String bucket : buckets) {
|
||||
orderList.add(bucket + " desc");
|
||||
}
|
||||
|
||||
for(Metric metric : metrics) {
|
||||
for(String column : metric.getColumns()) {
|
||||
if (!fieldsList.contains(column)) {
|
||||
fieldsList.add(column);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (orderList.isEmpty()) {
|
||||
|
@ -174,6 +185,32 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
|||
};
|
||||
}
|
||||
|
||||
private List<Metric> buildMetrics(List<Pair<String, String>> metricPairs) {
|
||||
List<Metric> metrics = new ArrayList<>(metricPairs.size());
|
||||
for(Pair<String, String> metricPair : metricPairs) {
|
||||
metrics.add(getMetric(metricPair));
|
||||
}
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private Metric getMetric(Pair<String, String> metricPair) {
|
||||
switch (metricPair.getKey()) {
|
||||
case "COUNT":
|
||||
return new CountMetric(metricPair.getValue());
|
||||
case "SUM":
|
||||
case "$SUM0":
|
||||
return new SumMetric(metricPair.getValue());
|
||||
case "MIN":
|
||||
return new MinMetric(metricPair.getValue());
|
||||
case "MAX":
|
||||
return new MaxMetric(metricPair.getValue());
|
||||
case "AVG":
|
||||
return new MeanMetric(metricPair.getValue());
|
||||
default:
|
||||
throw new IllegalArgumentException(metricPair.getKey());
|
||||
}
|
||||
}
|
||||
|
||||
public <T> Queryable<T> asQueryable(QueryProvider queryProvider, SchemaPlus schema, String tableName) {
|
||||
return new SolrQueryable<>(queryProvider, schema, this, tableName);
|
||||
}
|
||||
|
@ -190,7 +227,7 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
|||
}
|
||||
|
||||
public Enumerator<T> enumerator() {
|
||||
//noinspection unchecked
|
||||
@SuppressWarnings("unchecked")
|
||||
final Enumerable<T> enumerable = (Enumerable<T>) getTable().query(getProperties());
|
||||
return enumerable.enumerator();
|
||||
}
|
||||
|
@ -209,8 +246,8 @@ class SolrTable extends AbstractQueryableTable implements TranslatableTable {
|
|||
*/
|
||||
@SuppressWarnings("UnusedDeclaration")
|
||||
public Enumerable<Object> query(List<String> fields, String query, List<String> order, List<String> buckets,
|
||||
List<Metric> metrics, String limit) {
|
||||
return getTable().query(getProperties(), fields, query, order, buckets, metrics, limit);
|
||||
List<Pair<String, String>> metricPairs, String limit) {
|
||||
return getTable().query(getProperties(), fields, query, order, buckets, metricPairs, limit);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,14 +23,13 @@ import org.apache.calcite.linq4j.tree.Expression;
|
|||
import org.apache.calcite.linq4j.tree.Expressions;
|
||||
import org.apache.calcite.linq4j.tree.MethodCallExpression;
|
||||
import org.apache.calcite.plan.*;
|
||||
import org.apache.calcite.prepare.CalcitePrepareImpl;
|
||||
import org.apache.calcite.rel.RelNode;
|
||||
import org.apache.calcite.rel.convert.ConverterImpl;
|
||||
import org.apache.calcite.rel.metadata.RelMetadataQuery;
|
||||
import org.apache.calcite.rel.type.RelDataType;
|
||||
import org.apache.calcite.runtime.Hook;
|
||||
import org.apache.calcite.util.BuiltInMethod;
|
||||
import org.apache.solr.client.solrj.io.stream.metrics.Metric;
|
||||
import org.apache.calcite.util.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -55,7 +54,7 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
|
|||
}
|
||||
|
||||
public Result implement(EnumerableRelImplementor implementor, Prefer pref) {
|
||||
// Generates a call to "query" with the appropriate fields and filterQueries
|
||||
// Generates a call to "query" with the appropriate fields
|
||||
final BlockBuilder list = new BlockBuilder();
|
||||
final SolrRel.Implementor solrImplementor = new SolrRel.Implementor();
|
||||
solrImplementor.visitChild(0, getInput());
|
||||
|
@ -64,17 +63,14 @@ class SolrToEnumerableConverter extends ConverterImpl implements EnumerableRel {
|
|||
final Expression table = list.append("table", solrImplementor.table.getExpression(SolrTable.SolrQueryable.class));
|
||||
final Expression fields = list.append("fields",
|
||||
constantArrayList(generateFields(SolrRules.solrFieldNames(rowType), solrImplementor.fieldMappings), String.class));
|
||||
final Expression filterQueries = list.append("query", Expressions.constant(solrImplementor.query, String.class));
|
||||
final Expression query = list.append("query", Expressions.constant(solrImplementor.query, String.class));
|
||||
final Expression order = list.append("order", constantArrayList(solrImplementor.order, String.class));
|
||||
final Expression buckets = list.append("buckets", constantArrayList(solrImplementor.buckets, String.class));
|
||||
final Expression metrics = list.append("metrics", constantArrayList(solrImplementor.metrics, Metric.class));
|
||||
final Expression metricPairs = list.append("metricPairs", constantArrayList(solrImplementor.metricPairs, Pair.class));
|
||||
final Expression limit = list.append("limit", Expressions.constant(solrImplementor.limitValue));
|
||||
Expression enumerable = list.append("enumerable", Expressions.call(table, SolrMethod.SOLR_QUERYABLE_QUERY.method,
|
||||
fields, filterQueries, order, buckets, metrics, limit));
|
||||
if (CalcitePrepareImpl.DEBUG) {
|
||||
System.out.println("Solr: " + filterQueries);
|
||||
}
|
||||
Hook.QUERY_PLAN.run(filterQueries);
|
||||
fields, query, order, buckets, metricPairs, limit));
|
||||
Hook.QUERY_PLAN.run(query);
|
||||
list.add(Expressions.return_(null, enumerable));
|
||||
return implementor.result(physType, list.toBlock());
|
||||
}
|
||||
|
|
|
@ -24,40 +24,50 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
|
|||
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
|
||||
|
||||
public class CountMetric extends Metric {
|
||||
private String columnName;
|
||||
private long count;
|
||||
|
||||
public CountMetric(){
|
||||
init("count");
|
||||
|
||||
public CountMetric() {
|
||||
this("*");
|
||||
}
|
||||
|
||||
public CountMetric(String columnName) {
|
||||
init("count", columnName);
|
||||
}
|
||||
|
||||
public CountMetric(StreamExpression expression, StreamFactory factory) throws IOException{
|
||||
// grab all parameters out
|
||||
String functionName = expression.getFunctionName();
|
||||
String columnName = factory.getValueOperand(expression, 0);
|
||||
|
||||
// validate expression contains only what we want.
|
||||
if(!"*".equals(columnName)){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expected %s(*)", expression, functionName));
|
||||
}
|
||||
|
||||
if(1 != expression.getParameters().size()){
|
||||
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found", expression));
|
||||
}
|
||||
|
||||
init(functionName);
|
||||
|
||||
|
||||
init(functionName, columnName);
|
||||
}
|
||||
|
||||
public String[] getColumns() {
|
||||
return new String[0];
|
||||
if(isAllColumns()) {
|
||||
return new String[0];
|
||||
}
|
||||
return new String[]{columnName};
|
||||
}
|
||||
|
||||
private void init(String functionName){
|
||||
|
||||
private void init(String functionName, String columnName){
|
||||
this.columnName = columnName;
|
||||
setFunctionName(functionName);
|
||||
setIdentifier(functionName, "(*)");
|
||||
setIdentifier(functionName, "(", columnName, ")");
|
||||
}
|
||||
|
||||
private boolean isAllColumns() {
|
||||
return "*".equals(this.columnName);
|
||||
}
|
||||
|
||||
public void update(Tuple tuple) {
|
||||
++count;
|
||||
if(isAllColumns() || tuple.get(columnName) != null) {
|
||||
++count;
|
||||
}
|
||||
}
|
||||
|
||||
public Long getValue() {
|
||||
|
@ -65,10 +75,11 @@ public class CountMetric extends Metric {
|
|||
}
|
||||
|
||||
public Metric newInstance() {
|
||||
return new CountMetric();
|
||||
return new CountMetric(columnName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
|
||||
return new StreamExpression(getFunctionName()).withParameter("*");
|
||||
return new StreamExpression(getFunctionName()).withParameter(columnName);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue