Add support for percentile and percentile ranks aggs

Original commit: elastic/x-pack-elasticsearch@dc443ed465
This commit is contained in:
Costin Leau 2017-08-05 20:35:33 +03:00
parent 335838db08
commit 2c830bdff2
44 changed files with 1326 additions and 642 deletions

View File

@ -20,10 +20,8 @@ import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.jdbc.framework.JdbcAssert.assertResultSets;
/**
@ -50,7 +48,9 @@ public class CsvSpecIT extends SpecBaseIntegrationTestCase {
CsvSpecParser parser = new CsvSpecParser();
return CollectionUtils.combine(
readScriptSpec("/command.csv-spec", parser),
readScriptSpec("/fulltext.csv-spec", parser));
readScriptSpec("/fulltext.csv-spec", parser),
readScriptSpec("/agg.csv-spec", parser)
);
}
public CsvSpecIT(String groupName, String testName, Integer lineNumber, Path source, CsvTestCase testCase) {
@ -62,7 +62,7 @@ public class CsvSpecIT extends SpecBaseIntegrationTestCase {
try {
assertMatchesCsv(testCase.query, testName, testCase.expectedResults);
} catch (AssertionError ae) {
throw reworkException(new AssertionError(errorMessage(ae), ae.getCause()));
throw reworkException(ae);
}
}
@ -86,22 +86,20 @@ public class CsvSpecIT extends SpecBaseIntegrationTestCase {
.executeQuery("SELECT * FROM " + csvTableName);
// trigger data loading for type inference
expected.beforeFirst();
Statement statement = es.createStatement();
//statement.setFetchSize(randomInt(10));
// NOCOMMIT: hook up pagination
// NOCOMMIT sometimes accept the default fetch size. I believe it is 0 now which breaks things.
statement.setFetchSize(1000);
ResultSet actual = statement.executeQuery(query);
ResultSet actual = executeJdbcQuery(es, query);
assertResultSets(expected, actual);
}
}
String errorMessage(Throwable th) {
return format(Locale.ROOT, "test%s@%s:%d failed\n\"%s\"\n%s", testName, source.getFileName().toString(), lineNumber,
testCase.query, th.getMessage());
private ResultSet executeJdbcQuery(Connection con, String query) throws SQLException {
Statement statement = con.createStatement();
//statement.setFetchSize(randomInt(10));
// NOCOMMIT: hook up pagination
statement.setFetchSize(1000);
return statement.executeQuery(query);
}
private static class CsvSpecParser implements Parser {
protected static class CsvSpecParser implements Parser {
private final StringBuilder data = new StringBuilder();
private CsvTestCase testCase;
@ -137,7 +135,7 @@ public class CsvSpecIT extends SpecBaseIntegrationTestCase {
}
}
private static class CsvTestCase {
protected static class CsvTestCase {
String query;
String expectedResults;
}

View File

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.elasticsearch.xpack.sql.jdbc.framework.JdbcTestUtils;
import java.nio.file.Path;
import java.util.List;
public abstract class DebugCsvSpec extends CsvSpecIT {
@ParametersFactory(shuffle = false, argumentFormatting = SqlSpecIT.PARAM_FORMATTING) // NOCOMMIT are we sure?!
public static List<Object[]> readScriptSpec() throws Exception {
JdbcTestUtils.sqlLogging();
CsvSpecParser parser = new CsvSpecParser();
return readScriptSpec("/debug.csv-spec", parser);
}
public DebugCsvSpec(String groupName, String testName, Integer lineNumber, Path source, CsvTestCase testCase) {
super(groupName, testName, lineNumber, source, testCase);
}
// @Override
// public void assertResults(ResultSet expected, ResultSet actual) throws SQLException {
// Logger logger = Loggers.getLogger("org.elasticsearch.xpack.sql.test");
// Loggers.setLevel(logger, "INFO");
//
// JdbcTestUtils.resultSetToLogger(logger, actual);
// }
}

View File

@ -39,7 +39,7 @@ public class SqlSpecIT extends SpecBaseIntegrationTestCase {
// example for enabling logging
//JdbcTestUtils.sqlLogging();
SqlSpecParser parser = new SqlSpecParser();
Parser parser = parser();
return CollectionUtils.combine(
readScriptSpec("/select.sql-spec", parser),
readScriptSpec("/filter.sql-spec", parser),
@ -63,6 +63,10 @@ public class SqlSpecIT extends SpecBaseIntegrationTestCase {
}
}
static SqlSpecParser parser() {
return new SqlSpecParser();
}
public SqlSpecIT(String groupName, String testName, Integer lineNumber, Path source, String query) {
super(groupName, testName, lineNumber, source);
this.query = query;
@ -74,17 +78,17 @@ public class SqlSpecIT extends SpecBaseIntegrationTestCase {
Connection es = esJdbc()) {
ResultSet expected, actual;
try {
expected = executeQuery(h2);
actual = executeQuery(es);
expected = executeJdbcQuery(h2);
actual = executeJdbcQuery(es);
assertResultSets(expected, actual);
} catch (AssertionError ae) {
throw reworkException(new AssertionError(errorMessage(ae), ae.getCause()));
throw reworkException(ae);
}
}
}
private ResultSet executeQuery(Connection con) throws SQLException {
private ResultSet executeJdbcQuery(Connection con) throws SQLException {
Statement statement = con.createStatement();
//statement.setFetchSize(randomInt(10));
// NOCOMMIT: hook up pagination

View File

@ -10,10 +10,13 @@ import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.List;
import java.util.Locale;
import java.util.TimeZone;
import static java.lang.String.format;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@ -29,7 +32,22 @@ public class JdbcAssert {
ResultSetMetaData expectedMeta = expected.getMetaData();
ResultSetMetaData actualMeta = actual.getMetaData();
assertEquals("Different number of columns returned", expectedMeta.getColumnCount(), actualMeta.getColumnCount());
if (expectedMeta.getColumnCount() != actualMeta.getColumnCount()) {
List<String> expectedCols = new ArrayList<>();
for (int i = 1; i <= expectedMeta.getColumnCount(); i++) {
expectedCols.add(expectedMeta.getColumnName(i));
}
List<String> actualCols = new ArrayList<>();
for (int i = 1; i <= actualMeta.getColumnCount(); i++) {
actualCols.add(actualMeta.getColumnName(i));
}
assertEquals(format(Locale.ROOT, "Different number of columns returned (expected %d but was %d);",
expectedMeta.getColumnCount(), actualMeta.getColumnCount()),
expectedCols.toString(), actualCols.toString());
}
for (int column = 1; column <= expectedMeta.getColumnCount(); column++) {
String expectedName = expectedMeta.getColumnName(column);

View File

@ -5,9 +5,13 @@
*/
package org.elasticsearch.xpack.sql.jdbc.framework;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Map;
import java.util.Map.Entry;
@ -28,4 +32,84 @@ public abstract class JdbcTestUtils {
Loggers.setLevel(Loggers.getLogger(entry.getKey()), entry.getValue());
}
}
public static void printResultSet(ResultSet set) throws Exception {
Logger logger = Loggers.getLogger("org.elasticsearch.xpack.sql.test");
Loggers.setLevel(logger, "INFO");
ResultSetMetaData metaData = set.getMetaData();
// header
StringBuilder sb = new StringBuilder();
int colSize = 15;
for (int column = 1; column <= metaData.getColumnCount(); column++) {
String colName = metaData.getColumnName(column);
int size = colName.length();
if (column > 1) {
sb.append("|");
size++;
}
sb.append(colName);
for (int i = size; i < colSize; i++) {
sb.append(" ");
}
}
logger.info(sb.toString());
}
private static final int MAX_WIDTH = 20;
public static void resultSetToLogger(Logger log, ResultSet rs) throws SQLException {
ResultSetMetaData metaData = rs.getMetaData();
StringBuilder sb = new StringBuilder();
StringBuilder column = new StringBuilder();
int columns = metaData.getColumnCount();
for (int i = 1; i <= columns; i++) {
if (i > 1) {
sb.append(" | ");
}
column.setLength(0);
column.append(metaData.getColumnName(i));
column.append("(");
column.append(metaData.getColumnTypeName(i));
column.append(")");
sb.append(trimOrPad(column));
}
int l = sb.length();
sb.append("\n");
for (int i = 0; i < l; i++) {
sb.append("=");
}
log.info(sb);
while (rs.next()) {
sb.setLength(0);
for (int i = 1; i <= columns; i++) {
column.setLength(0);
if (i > 1) {
sb.append(" | ");
}
sb.append(trimOrPad(column.append(rs.getString(i))));
}
log.info(sb);
}
}
private static StringBuilder trimOrPad(StringBuilder buffer) {
if (buffer.length() > MAX_WIDTH) {
buffer.setLength(MAX_WIDTH - 1);
buffer.append("~");
}
else {
for (int i = buffer.length(); i < MAX_WIDTH; i++) {
buffer.append(" ");
}
}
return buffer;
}
}

View File

@ -0,0 +1,67 @@
//
// Aggs not supported by H2 / traditional SQL stores
//
singlePercentileWithoutComma
SELECT gender, PERCENTILE(emp_no, 97) p1 FROM test_emp GROUP BY gender;
gender | p1
M | 10095.6112
F | 10099.1936
;
singlePercentileWithComma
SELECT gender, PERCENTILE(emp_no, 97.76) p1 FROM test_emp GROUP BY gender;
gender | p1
M | 10095.6112
F | 10099.1936
;
multiplePercentilesOneWithCommaOneWithout
SELECT gender, PERCENTILE(emp_no, 92.45) p1, PERCENTILE(emp_no, 91) p2 FROM test_emp GROUP BY gender;
gender | p1 | p2
M | 10090.319 | 10087.68
F | 10095.128 | 10093.52
;
multiplePercentilesWithoutComma
SELECT gender, PERCENTILE(emp_no, 91) p1, PERCENTILE(emp_no, 89) p2 FROM test_emp GROUP BY gender;
gender | p1 | p2
M | 10087.68 | 10085.18
F | 10093.52 | 10092.08
;
multiplePercentilesWithComma
SELECT gender, PERCENTILE(emp_no, 85.7) p1, PERCENTILE(emp_no, 94.3) p2 FROM test_emp GROUP BY gender;
gender | p1 | p2
M | 10083.134 | 10091.932
F | 10088.852 | 10097.792
;
percentileRank
SELECT gender, PERCENTILE_RANK(emp_no, 10025) rank FROM test_emp GROUP BY gender;
gender | rank
M | 23.41269841269841
F | 26.351351351351347
;
multiplePercentileRanks
SELECT gender, PERCENTILE_RANK(emp_no, 10030.0) rank1, PERCENTILE_RANK(emp_no, 10025) rank2 FROM test_emp GROUP BY gender;
gender | rank1 | rank2
M | 29.365079365079367 | 23.41269841269841
F | 29.93762993762994 | 26.351351351351347
;
multiplePercentilesAndPercentileRank
SELECT gender, PERCENTILE(emp_no, 97.76) p1, PERCENTILE(emp_no, 93.3) p2, PERCENTILE_RANK(emp_no, 10025) rank FROM test_emp GROUP BY gender;
gender | p1 | p2 | rank
M | 10095.6112 | 10090.846 | 23.41269841269841
F | 10099.1936 | 10096.351999999999 | 26.351351351351347
;

View File

@ -3,51 +3,9 @@
//
debug
SHOW FUNCTIONS;
SELECT gender, PERCENTILE(emp_no, 97.76) p1, PERCENTILE(emp_no, 93.3) p2, PERCENTILE_RANK(emp_no, 10025) rank FROM test_emp GROUP BY gender;
name | type
AVG |AGGREGATE
COUNT |AGGREGATE
MAX |AGGREGATE
MIN |AGGREGATE
SUM |AGGREGATE
DAY_OF_MONTH |SCALAR
DAY |SCALAR
DOM |SCALAR
DAY_OF_WEEK |SCALAR
DOW |SCALAR
DAY_OF_YEAR |SCALAR
DOY |SCALAR
HOUR_OF_DAY |SCALAR
HOUR |SCALAR
MINUTE_OF_DAY |SCALAR
MINUTE_OF_HOUR |SCALAR
MINUTE |SCALAR
SECOND_OF_MINUTE|SCALAR
SECOND |SCALAR
MONTH_OF_YEAR |SCALAR
MONTH |SCALAR
YEAR |SCALAR
ABS |SCALAR
ACOS |SCALAR
ASIN |SCALAR
ATAN |SCALAR
CBRT |SCALAR
CEIL |SCALAR
COS |SCALAR
COSH |SCALAR
DEGREES |SCALAR
E |SCALAR
EXP |SCALAR
EXPM1 |SCALAR
FLOOR |SCALAR
LOG |SCALAR
LOG10 |SCALAR
PI |SCALAR
RADIANS |SCALAR
ROUND |SCALAR
SIN |SCALAR
SINH |SCALAR
SQRT |SCALAR
TAN |SCALAR
gender | p1 | p2 | rank
M | 10095.6112 | 10090.846 | 23.41269841269841
F | 10099.1936 | 10096.351999999999 | 26.351351351351347
;

View File

@ -823,7 +823,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
if (f instanceof Count) {
Count c = (Count) f;
if (!c.distinct()) {
if (c.argument() instanceof Literal && c.argument().dataType().isInteger()) {
if (c.field() instanceof Literal && c.field().dataType().isInteger()) {
return true;
}
}

View File

@ -5,9 +5,6 @@
*/
package org.elasticsearch.xpack.sql.expression;
import java.util.List;
import java.util.Locale;
import org.elasticsearch.xpack.sql.capabilities.Resolvable;
import org.elasticsearch.xpack.sql.capabilities.Resolvables;
import org.elasticsearch.xpack.sql.tree.Location;
@ -16,6 +13,9 @@ import org.elasticsearch.xpack.sql.tree.NodeUtils;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.util.List;
import java.util.Locale;
import static java.lang.String.format;
public abstract class Expression extends Node<Expression> implements Resolvable {
@ -78,7 +78,7 @@ public abstract class Expression extends Node<Expression> implements Resolvable
return lazyChildrenResolved;
}
public TypeResolution typeResolved() {
public final TypeResolution typeResolved() {
if (lazyTypeResolution == null) {
lazyTypeResolution = resolveType();
}

View File

@ -40,6 +40,15 @@ public abstract class Expressions {
return false;
}
public static boolean nullable(List<? extends Expression> exps) {
for (Expression exp : exps) {
if (!exp.nullable()) {
return false;
}
}
return true;
}
public static AttributeSet references(List<? extends Expression> exps) {
if (exps.isEmpty()) {
return AttributeSet.EMPTY;

View File

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypeConvertion;
import org.elasticsearch.xpack.sql.type.DataTypes;
import java.util.ArrayList;
import java.util.List;
public abstract class Foldables {
public static <T> T valueOf(Expression e, DataType to) {
if (e.foldable()) {
return DataTypeConvertion.convert(e.fold(), e.dataType(), to);
}
throw new SqlIllegalArgumentException("Cannot determine value for %s", e);
}
public static Object valueOf(Expression e) {
if (e.foldable()) {
return e.fold();
}
throw new SqlIllegalArgumentException("Cannot determine value for %s", e);
}
public static String stringValueOf(Expression e) {
return valueOf(e, DataTypes.KEYWORD);
}
public static Integer intValueOf(Expression e) {
return valueOf(e, DataTypes.INTEGER);
}
public static Long longValueOf(Expression e) {
return valueOf(e, DataTypes.LONG);
}
public static double doubleValueOf(Expression e) {
return valueOf(e, DataTypes.DOUBLE);
}
public static <T> List<T> valuesOf(List<Expression> list, DataType to) {
List<T> l = new ArrayList<>();
for (Expression e : list) {
if (e.foldable()) {
l.add(DataTypeConvertion.convert(e.fold(), e.dataType(), to));
}
else {
throw new SqlIllegalArgumentException("Cannot determine value for %s", e);
}
}
return l;
}
public static List<Double> doubleValuesOf(List<Expression> list) {
return valuesOf(list, DataTypes.DOUBLE);
}
}

View File

@ -18,9 +18,4 @@ public abstract class LeafExpression extends Expression {
public AttributeSet references() {
return AttributeSet.EMPTY;
}
@Override
public int hashCode() {
throw new UnsupportedOperationException();
}
}

View File

@ -5,14 +5,13 @@
*/
package org.elasticsearch.xpack.sql.expression;
import java.util.Collections;
import java.util.Objects;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
public class Literal extends Expression {
import java.util.Objects;
public class Literal extends LeafExpression {
private final Object value;
private final DataType dataType;
@ -21,7 +20,7 @@ public class Literal extends Expression {
public static final Literal FALSE = Literal.of(Location.EMPTY, Boolean.FALSE);
public Literal(Location location, Object value, DataType dataType) {
super(location, Collections.emptyList());
super(location);
this.value = value;
this.dataType = dataType;
}

View File

@ -7,22 +7,18 @@ package org.elasticsearch.xpack.sql.expression.function;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.aware.DistinctAware;
import org.elasticsearch.xpack.sql.expression.function.aware.TimeZoneAware;
import org.elasticsearch.xpack.sql.parser.ParsingException;
import org.elasticsearch.xpack.sql.session.SqlSettings;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.tree.Node;
import org.elasticsearch.xpack.sql.tree.NodeUtils;
import org.elasticsearch.xpack.sql.tree.NodeUtils.NodeInfo;
import org.elasticsearch.xpack.sql.util.Assert;
import org.elasticsearch.xpack.sql.util.StringUtils;
import org.joda.time.DateTimeZone;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
@ -98,46 +94,61 @@ abstract class AbstractFunctionRegistry implements FunctionRegistry {
return StringUtils.camelCaseToUnderscore(name);
}
//
// Instantiates a function through reflection.
// Picks up the constructor by expecting to be of type (Location,Expression) or (Location,List<Expression>) depending on the size of given children, parameters.
// If the function has certain 'aware'-ness (based on the interface implemented), the appropriate types are added to the signature
@SuppressWarnings("rawtypes")
private static Function createInstance(Class<? extends Function> clazz, UnresolvedFunction ur, SqlSettings settings) {
NodeInfo info = NodeUtils.info((Class<? extends Node>) clazz);
Class<?> exp = ur.children().size() == 1 ? Expression.class : List.class;
Object expVal = exp == Expression.class ? ur.children().get(0) : ur.children();
Class<?>[] pTypes = info.ctr.getParameterTypes();
boolean noExpression = false;
boolean distinctAware = DistinctAware.class.isAssignableFrom(clazz);
boolean timezoneAware = TimeZoneAware.class.isAssignableFrom(clazz);
// check constructor signature
// constructor types - location - distinct? - timezone?
int expectedParamCount = pTypes.length - (1 + (distinctAware ? 1 : 0) + (timezoneAware ? 1 : 0));
// check constructor signature
if (ur.children().size() != expectedParamCount) {
List<String> expected = new ArrayList<>();
for (int i = 1; i < expectedParamCount; i++) {
expected.add(pTypes[i].getSimpleName());
}
throw new ParsingException(ur.location(), "Invalid number of arguments given to function [%s], expected %d argument(s):%s but received %d:%s",
ur.name(), expected.size(), expected.toString(), ur.children().size(), ur.children());
}
// validate distinct ctor
if (!distinctAware && ur.distinct()) {
throw new ParsingException(ur.location(), "Function [%s] does not support DISTINCT yet it was specified", ur.name());
}
List<Class> ctorSignature = new ArrayList<>();
ctorSignature.add(Location.class);
// might be a constant function
if (expVal instanceof List && ((List) expVal).isEmpty()) {
noExpression = Arrays.equals(new Class[] { Location.class }, info.ctr.getParameterTypes());
}
else {
ctorSignature.add(exp);
}
// aware stuff
if (distinctAware) {
ctorSignature.add(boolean.class);
}
if (timezoneAware) {
ctorSignature.add(DateTimeZone.class);
}
// validate
Assert.isTrue(Arrays.equals(ctorSignature.toArray(new Class[ctorSignature.size()]), info.ctr.getParameterTypes()),
"No constructor with signature %s found for [%s]", ctorSignature, clazz.getTypeName());
// List<Class> ctorSignature = new ArrayList<>();
// ctorSignature.add(Location.class);
//
// // might be a constant function
// if (expVal instanceof List && ((List) expVal).isEmpty()) {
// noExpression = Arrays.equals(new Class[] { Location.class }, info.ctr.getParameterTypes());
// }
// else {
// ctorSignature.add(exp);
// }
//
// // aware stuff
// if (distinctAware) {
// ctorSignature.add(boolean.class);
// }
// if (timezoneAware) {
// ctorSignature.add(DateTimeZone.class);
// }
//
// // validate
// Assert.isTrue(Arrays.equals(ctorSignature.toArray(new Class[ctorSignature.size()]), info.ctr.getParameterTypes()),
// "No constructor with signature %s found for [%s], found %s instead", ctorSignature, clazz.getTypeName(), info.ctr);
// now add the actual values
try {
@ -147,15 +158,15 @@ abstract class AbstractFunctionRegistry implements FunctionRegistry {
args.add(ur.location());
// has multiple arguments
if (!noExpression) {
args.add(expVal);
if (distinctAware) {
args.add(ur.distinct());
}
if (timezoneAware) {
args.add(settings.timeZone());
}
args.addAll(ur.children());
if (distinctAware) {
args.add(ur.distinct());
}
if (timezoneAware) {
args.add(settings.timeZone());
}
return (Function) info.ctr.newInstance(args.toArray());
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) {
throw new SqlIllegalArgumentException(ex, "Cannot create instance of function %s", ur.name());

View File

@ -13,6 +13,8 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Mean;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness;
import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
@ -92,8 +94,11 @@ public class DefaultFunctionRegistry extends AbstractFunctionRegistry {
VarPop.class,
SumOfSquares.class,
Skewness.class,
Kurtosis.class
// TODO: add multi arg functions like Covariance, Correlate, Percentiles and percentiles rank
Kurtosis.class,
Percentile.class,
PercentileRank.class
// TODO: add multi arg functions like Covariance, Correlate
);
}

View File

@ -8,20 +8,34 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import java.util.List;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
public abstract class AggregateFunction extends Function {
private final Expression argument;
private final Expression field;
private final List<Expression> arguments;
AggregateFunction(Location location, Expression child) {
super(location, singletonList(child));
this.argument = child;
AggregateFunction(Location location, Expression field) {
this(location, field, emptyList());
}
public Expression argument() {
return argument;
AggregateFunction(Location location, Expression field, List<Expression> arguments) {
super(location, CollectionUtils.combine(singletonList(field), arguments));
this.field = field;
this.arguments = arguments;
}
public Expression field() {
return field;
}
public List<Expression> arguments() {
return arguments;
}
public String functionId() {

View File

@ -8,11 +8,17 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import java.util.List;
// marker type for compound aggregates, that is aggregate that provide multiple values (like Stats or Matrix)
// and thus cannot be used directly in SQL and are mainly for internal use
public abstract class CompoundAggregate extends NumericAggregate {
public abstract class CompoundNumericAggregate extends NumericAggregate {
public CompoundAggregate(Location location, Expression argument) {
super(location, argument);
CompoundNumericAggregate(Location location, Expression field, List<Expression> arguments) {
super(location, field, arguments);
}
CompoundNumericAggregate(Location location, Expression field) {
super(location, field);
}
}

View File

@ -12,12 +12,12 @@ import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
public class Count extends AggregateFunction implements DistinctAware {
public class Count extends NumericAggregate implements DistinctAware {
private final boolean distinct;
public Count(Location location, Expression argument, boolean distinct) {
super(location, argument);
public Count(Location location, Expression field, boolean distinct) {
super(location, field);
this.distinct = distinct;
}
@ -30,13 +30,12 @@ public class Count extends AggregateFunction implements DistinctAware {
return DataTypes.LONG;
}
@Override
public String functionId() {
String functionId = id().toString();
// if count works against a given expression, use its id (to identify the group)
if (argument() instanceof NamedExpression) {
functionId = ((NamedExpression) argument()).id().toString();
if (field() instanceof NamedExpression) {
functionId = ((NamedExpression) field()).id().toString();
}
return functionId;
}

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
// Agg 'enclosed' by another agg. Used for agg that return multiple embedded aggs (like MatrixStats)
public interface EnclosedAgg {
String innerName();

View File

@ -8,9 +8,9 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class ExtendedStats extends CompoundAggregate {
public class ExtendedStats extends CompoundNumericAggregate {
public ExtendedStats(Location location, Expression argument) {
super(location, argument);
public ExtendedStats(Location location, Expression field) {
super(location, field);
}
}

View File

@ -13,17 +13,17 @@ import org.elasticsearch.xpack.sql.type.DataType;
public class InnerAggregate extends AggregateFunction {
private final AggregateFunction inner;
private final CompoundAggregate outer;
private final CompoundNumericAggregate outer;
private final String innerId;
// used when the result needs to be extracted from a map (like in MatrixAggs)
// used when the result needs to be extracted from a map (like in MatrixAggs or Percentiles)
private final Expression innerKey;
public InnerAggregate(AggregateFunction inner, CompoundAggregate outer) {
public InnerAggregate(AggregateFunction inner, CompoundNumericAggregate outer) {
this(inner, outer, null);
}
public InnerAggregate(AggregateFunction inner, CompoundAggregate outer, Expression innerKey) {
super(inner.location(), outer.argument());
public InnerAggregate(AggregateFunction inner, CompoundNumericAggregate outer, Expression innerKey) {
super(inner.location(), outer.field(), outer.arguments());
this.inner = inner;
this.outer = outer;
this.innerId = ((EnclosedAgg) inner).innerName();
@ -34,7 +34,7 @@ public class InnerAggregate extends AggregateFunction {
return inner;
}
public CompoundAggregate outer() {
public CompoundNumericAggregate outer() {
return outer;
}

View File

@ -8,9 +8,9 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class MatrixStats extends CompoundAggregate {
public class MatrixStats extends CompoundNumericAggregate {
public MatrixStats(Location location, Expression argument) {
super(location, argument);
public MatrixStats(Location location, Expression field) {
super(location, field);
}
}

View File

@ -17,7 +17,7 @@ public class Max extends NumericAggregate implements EnclosedAgg {
@Override
public DataType dataType() {
return argument().dataType();
return field().dataType();
}
@Override

View File

@ -17,7 +17,7 @@ public class Min extends NumericAggregate implements EnclosedAgg {
@Override
public DataType dataType() {
return argument().dataType();
return field().dataType();
}
@Override

View File

@ -11,17 +11,24 @@ import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
abstract class NumericAggregate extends AggregateFunction {
import java.util.List;
NumericAggregate(Location location, Expression argument) {
super(location, argument);
class NumericAggregate extends AggregateFunction {
NumericAggregate(Location location, Expression field, List<Expression> arguments) {
super(location, field, arguments);
}
NumericAggregate(Location location, Expression field) {
super(location, field);
}
@Override
protected TypeResolution resolveType() {
return argument().dataType().isNumeric() ?
TypeResolution.TYPE_RESOLVED :
new TypeResolution("Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')", functionName(), Expressions.name(argument()), argument().dataType().esName());
return field().dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED : new TypeResolution(
"Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')", functionName(),
Expressions.name(field()), field().dataType().esName());
}
@Override

View File

@ -1,30 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
abstract class NumericAggregateFunction extends AggregateFunction {
NumericAggregateFunction(Location location, Expression argument) {
super(location, argument);
}
@Override
protected TypeResolution resolveType() {
return argument().dataType().isNumeric() ?
TypeResolution.TYPE_RESOLVED :
new TypeResolution("Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')", functionName(), Expressions.name(argument()), argument().dataType().esName());
}
@Override
public DataType dataType() {
return argument().dataType();
}
}

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
import java.util.Objects;
import static java.util.Collections.singletonList;
public class Percentile extends AggregateFunction implements EnclosedAgg {
private final Expression percent;
public Percentile(Location location, Expression field, Expression percent) {
super(location, field, singletonList(percent));
this.percent = percent;
}
@Override
protected TypeResolution resolveType() {
TypeResolution resolution = field().dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED :
new TypeResolution("Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')",
functionName(), Expressions.name(field()), field().dataType().esName());
if (TypeResolution.TYPE_RESOLVED.equals(resolution)) {
resolution = percent().dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED :
new TypeResolution("Percentile#percent argument cannot be non-numeric (type is'%s')", percent().dataType().esName());
}
return resolution;
}
public Expression percent() {
return percent;
}
@Override
public DataType dataType() {
return DataTypes.DOUBLE;
}
@Override
public String innerName() {
return "[" + Double.toString(Foldables.doubleValueOf(percent)) + "]";
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Percentile other = (Percentile) obj;
return Objects.equals(field(), other.field())
&& Objects.equals(percent, other.percent);
}
}

View File

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
import java.util.Objects;
import static java.util.Collections.singletonList;
public class PercentileRank extends AggregateFunction implements EnclosedAgg {
private final Expression value;
public PercentileRank(Location location, Expression field, Expression value) {
super(location, field, singletonList(value));
this.value = value;
}
@Override
protected TypeResolution resolveType() {
TypeResolution resolution = field().dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED :
new TypeResolution("Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')",
functionName(), Expressions.name(field()), field().dataType().esName());
if (TypeResolution.TYPE_RESOLVED.equals(resolution)) {
resolution = value.dataType().isNumeric() ? TypeResolution.TYPE_RESOLVED :
new TypeResolution("PercentileRank#value argument cannot be non-numeric (type is'%s')", value.dataType().esName());
}
return resolution;
}
public Expression value() {
return value;
}
@Override
public DataType dataType() {
return DataTypes.DOUBLE;
}
@Override
public String innerName() {
return "[" + Double.toString(Foldables.doubleValueOf(value)) + "]";
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PercentileRank other = (PercentileRank) obj;
return Objects.equals(field(), other.field())
&& Objects.equals(value, other.value);
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import java.util.List;
import java.util.Objects;
public class PercentileRanks extends CompoundNumericAggregate {
private final List<Expression> values;
public PercentileRanks(Location location, Expression field, List<Expression> values) {
super(location, field, values);
this.values = values;
}
public List<Expression> values() {
return values;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
PercentileRanks other = (PercentileRanks) obj;
return Objects.equals(field(), other.field())
&& Objects.equals(values, other.values);
}
}

View File

@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import java.util.List;
import java.util.Objects;
public class Percentiles extends CompoundNumericAggregate {
private final List<Expression> percents;
public Percentiles(Location location, Expression field, List<Expression> percents) {
super(location, field, percents);
this.percents = percents;
}
public List<Expression> percents() {
return percents;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null || getClass() != obj.getClass()) {
return false;
}
Percentiles other = (Percentiles) obj;
return Objects.equals(field(), other.field())
&& Objects.equals(percents, other.percents);
}
}

View File

@ -8,10 +8,10 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class Stats extends CompoundAggregate {
public class Stats extends CompoundNumericAggregate {
public Stats(Location location, Expression argument) {
super(location, argument);
public Stats(Location location, Expression field) {
super(location, field);
}
public static boolean isTypeCompatible(Expression e) {

View File

@ -17,7 +17,7 @@ public class Sum extends NumericAggregate implements EnclosedAgg {
@Override
public DataType dataType() {
return argument().dataType();
return field().dataType();
}
@Override

View File

@ -27,6 +27,10 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStatsEn
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentiles;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
import org.elasticsearch.xpack.sql.expression.predicate.And;
@ -50,7 +54,6 @@ import org.elasticsearch.xpack.sql.plan.logical.SubQueryAlias;
import org.elasticsearch.xpack.sql.rule.Rule;
import org.elasticsearch.xpack.sql.rule.RuleExecutor;
import org.elasticsearch.xpack.sql.session.EmptyExecutable;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
@ -71,6 +74,8 @@ import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.inComm
import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.splitAnd;
import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.splitOr;
import static org.elasticsearch.xpack.sql.expression.predicate.Predicates.subtract;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
public class Optimizer extends RuleExecutor<LogicalPlan> {
@ -101,7 +106,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
new CombineAggsToMatrixStats(),
new CombineAggsToExtendedStats(),
new CombineAggsToStats(),
new PromoteStatsToExtendedStats()
new PromoteStatsToExtendedStats(), new CombineAggsToPercentiles(), new CombineAggsToPercentileRanks()
);
Batch cleanup = new Batch("Operator Optimization",
@ -254,7 +259,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (e instanceof MatrixStatsEnclosed) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Expression argument = f.field();
MatrixStats matrixStats = seen.get(argument);
if (matrixStats == null) {
@ -262,7 +267,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
seen.put(argument, matrixStats);
}
InnerAggregate ia = new InnerAggregate(f, matrixStats, f.argument());
InnerAggregate ia = new InnerAggregate(f, matrixStats, f.field());
promotedIds.putIfAbsent(f.functionId(), ia.toAttribute());
return ia;
}
@ -276,7 +281,6 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
Map<Expression, ExtendedStats> seen = new LinkedHashMap<>();
p = p.transformExpressionsUp(e -> rule(e, seen, promotedFunctionIds));
// update old agg attributes
@ -292,7 +296,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (e instanceof ExtendedStatsEnclosed) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Expression argument = f.field();
ExtendedStats extendedStats = seen.get(argument);
if (extendedStats == null) {
@ -343,7 +347,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (Stats.isTypeCompatible(e)) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Expression argument = f.field();
Counter counter = seen.get(argument);
if (counter == null) {
@ -365,7 +369,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (Stats.isTypeCompatible(e)) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Expression argument = f.field();
Counter counter = seen.get(argument);
// if the stat has at least two different functions for it, promote it as stat
@ -412,7 +416,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
InnerAggregate ia = (InnerAggregate) e;
if (ia.outer() instanceof ExtendedStats) {
ExtendedStats extStats = (ExtendedStats) ia.outer();
seen.putIfAbsent(extStats.argument(), extStats);
seen.putIfAbsent(extStats.field(), extStats);
}
}
}
@ -422,8 +426,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
InnerAggregate ia = (InnerAggregate) e;
if (ia.outer() instanceof Stats) {
Stats stats = (Stats) ia.outer();
ExtendedStats ext = seen.get(stats.argument());
if (ext != null && stats.argument().equals(ext.argument())) {
ExtendedStats ext = seen.get(stats.field());
if (ext != null && stats.field().equals(ext.field())) {
return new InnerAggregate(ia.inner(), ext);
}
}
@ -433,6 +437,119 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
}
static class CombineAggsToPercentiles extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
// percentile per field/expression
Map<Expression, Set<Expression>> percentsPerField = new LinkedHashMap<>();
// count gather the percents for each field
p.forEachExpressionsUp(e -> count(e, percentsPerField));
Map<Expression, Percentiles> percentilesPerField = new LinkedHashMap<>();
// create a Percentile agg for each field (and its associated percents)
percentsPerField.forEach((k, v) -> {
percentilesPerField.put(k, new Percentiles(v.iterator().next().location(), k, new ArrayList<>(v)));
});
// now replace the agg with pointer to the main ones
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
p = p.transformExpressionsUp(e -> rule(e, percentilesPerField, promotedFunctionIds));
// finally update all the function references as well
return p.transformExpressionsDown(e -> CombineAggsToStats.updateFunctionAttrs(e, promotedFunctionIds));
}
private void count(Expression e, Map<Expression, Set<Expression>> percentsPerField) {
if (e instanceof Percentile) {
Percentile p = (Percentile) e;
Expression field = p.field();
Set<Expression> percentiles = percentsPerField.get(field);
if (percentiles == null) {
percentiles = new LinkedHashSet<>();
percentsPerField.put(field, percentiles);
}
percentiles.add(p.percent());
}
}
protected Expression rule(Expression e, Map<Expression, Percentiles> percentilesPerField, Map<String, AggregateFunctionAttribute> promotedIds) {
if (e instanceof Percentile) {
Percentile p = (Percentile) e;
Percentiles percentiles = percentilesPerField.get(p.field());
InnerAggregate ia = new InnerAggregate(p, percentiles);
promotedIds.putIfAbsent(p.functionId(), ia.toAttribute());
return ia;
}
return e;
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
}
static class CombineAggsToPercentileRanks extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
// percentile per field/expression
Map<Expression, Set<Expression>> valuesPerField = new LinkedHashMap<>();
// count gather the percents for each field
p.forEachExpressionsUp(e -> count(e, valuesPerField));
Map<Expression, PercentileRanks> ranksPerField = new LinkedHashMap<>();
// create a PercentileRanks agg for each field (and its associated values)
valuesPerField.forEach((k, v) -> {
ranksPerField.put(k, new PercentileRanks(v.iterator().next().location(), k, new ArrayList<>(v)));
});
// now replace the agg with pointer to the main ones
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
p = p.transformExpressionsUp(e -> rule(e, ranksPerField, promotedFunctionIds));
// finally update all the function references as well
return p.transformExpressionsDown(e -> CombineAggsToStats.updateFunctionAttrs(e, promotedFunctionIds));
}
private void count(Expression e, Map<Expression, Set<Expression>> ranksPerField) {
if (e instanceof PercentileRank) {
PercentileRank p = (PercentileRank) e;
Expression field = p.field();
Set<Expression> percentiles = ranksPerField.get(field);
if (percentiles == null) {
percentiles = new LinkedHashSet<>();
ranksPerField.put(field, percentiles);
}
percentiles.add(p.value());
}
}
protected Expression rule(Expression e, Map<Expression, PercentileRanks> ranksPerField, Map<String, AggregateFunctionAttribute> promotedIds) {
if (e instanceof PercentileRank) {
PercentileRank p = (PercentileRank) e;
PercentileRanks ranks = ranksPerField.get(p.field());
InnerAggregate ia = new InnerAggregate(p, ranks);
promotedIds.putIfAbsent(p.functionId(), ia.toAttribute());
return ia;
}
return e;
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
}
static class PruneFilters extends OptimizerRule<Filter> {
@ -783,7 +900,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// (a || b || c || ... ) && (a || b || d || ... ) => ((c || ...) && (d || ...)) || a || b
Expression combineLeft = combineOr(lDiff);
Expression combineRight = combineOr(rDiff);
return combineOr(CollectionUtils.combine(common, new And(combineLeft.location(), combineLeft, combineRight)));
return combineOr(combine(common, new And(combineLeft.location(), combineLeft, combineRight)));
}
if (bc instanceof Or) {
@ -821,7 +938,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// (a || b || c || ... ) && (a || b || d || ... ) => ((c || ...) && (d || ...)) || a || b
Expression combineLeft = combineAnd(lDiff);
Expression combineRight = combineAnd(rDiff);
return combineAnd(CollectionUtils.combine(common, new Or(combineLeft.location(), combineLeft, combineRight)));
return combineAnd(combine(common, new Or(combineLeft.location(), combineLeft, combineRight)));
}
// TODO: eliminate conjunction/disjunction

View File

@ -90,14 +90,6 @@ abstract class ExpressionBuilder extends IdentifierBuilder {
return expression(ctx.expression());
}
@Override
// Star will be resolved by the analyzer - currently just put a placeholder
public Expression visitStar(StarContext ctx) {
return new UnresolvedStar(source(ctx), ctx.qualifier != null ? visitColumnExpression(ctx.qualifier) : null);
}
@Override
public Expression visitSelectExpression(SelectExpressionContext ctx) {
Expression exp = expression(ctx.expression());
@ -108,6 +100,11 @@ abstract class ExpressionBuilder extends IdentifierBuilder {
return exp;
}
@Override
public Expression visitStar(StarContext ctx) {
return new UnresolvedStar(source(ctx), ctx.qualifier != null ? visitColumnExpression(ctx.qualifier) : null);
}
@Override
public Object visitDereference(DereferenceContext ctx) {
String fieldName = visitIdentifier(ctx.fieldName);
@ -286,6 +283,7 @@ abstract class ExpressionBuilder extends IdentifierBuilder {
if (ctx.setQuantifier() != null) {
isDistinct = (ctx.setQuantifier().DISTINCT() != null);
}
return new UnresolvedFunction(source(ctx), name, isDistinct, expressions(ctx.expression()));
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.ExpressionId;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.NestedFieldAttribute;
import org.elasticsearch.xpack.sql.expression.Order;
@ -17,7 +18,7 @@ import org.elasticsearch.xpack.sql.expression.RootFieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundNumericAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
@ -234,7 +235,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
Map<Attribute, Attribute> aliases = new LinkedHashMap<>();
// tracker for compound aggs seen in a group
Map<CompoundAggregate, String> compoundAggMap = new LinkedHashMap<>();
Map<CompoundNumericAggregate, String> compoundAggMap = new LinkedHashMap<>();
// followed by actual aggregates
for (NamedExpression ne : a.aggregates()) {
@ -338,7 +339,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
return queryC.addAggRef(aggInfo.propertyPath(), proc);
}
private QueryContainer addFunction(GroupingAgg parentAgg, AggregateFunction f, ColumnProcessor proc, Map<CompoundAggregate, String> compoundAggMap, QueryContainer queryC) {
private QueryContainer addFunction(GroupingAgg parentAgg, AggregateFunction f, ColumnProcessor proc, Map<CompoundNumericAggregate, String> compoundAggMap, QueryContainer queryC) {
String functionId = f.functionId();
// handle count as a special case agg
if (f instanceof Count) {
@ -354,7 +355,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (f instanceof InnerAggregate) {
InnerAggregate ia = (InnerAggregate) f;
CompoundAggregate outer = ia.outer();
CompoundNumericAggregate outer = ia.outer();
String cAggPath = compoundAggMap.get(outer);
// the compound agg hasn't been seen before so initialize it
@ -449,7 +450,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
protected PhysicalPlan rule(LimitExec plan) {
if (plan.child() instanceof EsQueryExec) {
EsQueryExec exec = (EsQueryExec) plan.child();
int limit = Integer.valueOf(QueryTranslator.valueOf(plan.limit()));
int limit = Foldables.intValueOf(plan.limit());
int currentSize = exec.queryContainer().limit();
int newSize = currentSize < 0 ? limit : Math.min(currentSize, limit);
return exec.with(exec.queryContainer().withLimit(newSize));

View File

@ -20,12 +20,14 @@ import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg;
import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundNumericAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentiles;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
@ -62,6 +64,8 @@ import org.elasticsearch.xpack.sql.querydsl.agg.MatrixStatsAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.MaxAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.MinAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.OrAggFilter;
import org.elasticsearch.xpack.sql.querydsl.agg.PercentileRanksAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.PercentilesAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.StatsAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.SumAgg;
import org.elasticsearch.xpack.sql.querydsl.query.AndQuery;
@ -93,6 +97,9 @@ import java.util.Map.Entry;
import static java.lang.String.format;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.sql.expression.Foldables.doubleValuesOf;
import static org.elasticsearch.xpack.sql.expression.Foldables.stringValueOf;
import static org.elasticsearch.xpack.sql.expression.Foldables.valueOf;
import static org.elasticsearch.xpack.sql.expression.function.scalar.script.ParamsBuilder.paramsBuilder;
import static org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate.formatTemplate;
@ -117,6 +124,8 @@ abstract class QueryTranslator {
new StatsAggs(),
new ExtendedStatsAggs(),
new MatrixStatsAggs(),
new PercentilesAggs(),
new PercentileRanksAggs(),
new DistinctCounts(),
new DateTimes()
);
@ -203,8 +212,8 @@ abstract class QueryTranslator {
if (groupPath != null) {
GroupingAgg matchingGroup = null;
// group found - finding the dedicated agg
if (f.argument() instanceof NamedExpression) {
matchingGroup = groupMap.get(((NamedExpression) f.argument()).id());
if (f.field() instanceof NamedExpression) {
matchingGroup = groupMap.get(((NamedExpression) f.field()).id());
}
// return matching group or the tail (last group)
return matchingGroup != null ? matchingGroup : tail;
@ -358,11 +367,6 @@ abstract class QueryTranslator {
return new NotQuery(query.location(), query);
}
@SuppressWarnings("unchecked")
static <T> T valueOf(Expression e) {
return (T) ((Literal) e).value();
}
static String nameOf(Expression e) {
if (e instanceof DateTimeFunction) {
return nameOf(((DateTimeFunction) e).argument());
@ -370,6 +374,9 @@ abstract class QueryTranslator {
if (e instanceof NamedExpression) {
return ((NamedExpression) e).name();
}
if (e instanceof Literal) {
return String.valueOf(e.fold());
}
throw new SqlIllegalArgumentException("Cannot determine name for %s", e);
}
@ -394,7 +401,7 @@ abstract class QueryTranslator {
}
static String field(AggregateFunction af) {
Expression arg = af.argument();
Expression arg = af.field();
if (arg instanceof RootFieldAttribute) {
return ((RootFieldAttribute) arg).name();
}
@ -420,7 +427,7 @@ abstract class QueryTranslator {
target = nameOf(analyzed ? fa : fa.notAnalyzedAttribute());
}
String pattern = sqlToEsPatternMatching(valueOf(e.right()));
String pattern = sqlToEsPatternMatching(stringValueOf(e.right()));
if (e instanceof Like) {
if (analyzed) {
q = new QueryStringQuery(e.location(), pattern, target);
@ -435,7 +442,7 @@ abstract class QueryTranslator {
q = new QueryStringQuery(e.location(), "/" + pattern + "/", target);
}
else {
q = new RegexQuery(e.location(), nameOf(e.left()), sqlToEsPatternMatching(valueOf(e.right())));
q = new RegexQuery(e.location(), nameOf(e.left()), sqlToEsPatternMatching(stringValueOf(e.right())));
}
}
@ -765,6 +772,22 @@ abstract class QueryTranslator {
}
}
static class PercentilesAggs extends CompoundAggTranslator<Percentiles> {
@Override
protected LeafAgg toAgg(String id, String path, Percentiles p) {
return new PercentilesAgg(id, path, field(p), doubleValuesOf(p.percents()));
}
}
static class PercentileRanksAggs extends CompoundAggTranslator<PercentileRanks> {
@Override
protected LeafAgg toAgg(String id, String path, PercentileRanks p) {
return new PercentileRanksAgg(id, path, field(p), doubleValuesOf(p.values()));
}
}
static class DateTimes extends SingleValueAggTranslator<Min> {
@Override
@ -796,7 +819,7 @@ abstract class QueryTranslator {
protected abstract LeafAgg toAgg(String id, String path, F f);
}
abstract static class CompoundAggTranslator<C extends CompoundAggregate> extends AggTranslator<C> {
abstract static class CompoundAggTranslator<C extends CompoundNumericAggregate> extends AggTranslator<C> {
@Override
protected final LeafAgg asAgg(String id, String parent, C function) {

View File

@ -41,7 +41,8 @@ public abstract class AggPath {
}
public static String metricValue(String aggPath, String valueName) {
return aggPath + VALUE_DELIMITER + valueName;
// handle aggPath inconsistency (for percentiles and percentileRanks) percentile[99.9] (valid) vs percentile.99.9 (invalid)
return valueName.startsWith("[") ? aggPath + valueName : aggPath + VALUE_DELIMITER + valueName;
}
public static String path(String parent, String child) {

View File

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import java.util.List;
import static org.elasticsearch.search.aggregations.AggregationBuilders.percentileRanks;
public class PercentileRanksAgg extends LeafAgg {
private final List<Double> values;
public PercentileRanksAgg(String id, String propertyPath, String fieldName, List<Double> values) {
super(id, propertyPath, fieldName);
this.values = values;
}
public List<Double> percents() {
return values;
}
@Override
AggregationBuilder toBuilder() {
// TODO: look at keyed
return percentileRanks(id())
.field(fieldName())
.values(values.stream().mapToDouble(Double::doubleValue).toArray());
}
}

View File

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import java.util.List;
import static org.elasticsearch.search.aggregations.AggregationBuilders.percentiles;
public class PercentilesAgg extends LeafAgg {
private final List<Double> percents;
public PercentilesAgg(String id, String propertyPath, String fieldName, List<Double> percents) {
super(id, propertyPath, fieldName);
this.percents = percents;
}
public List<Double> percents() {
return percents;
}
@Override
AggregationBuilder toBuilder() {
// TODO: look at keyed
return percentiles(id())
.field(fieldName())
.percentiles(percents.stream().mapToDouble(Double::doubleValue).toArray());
}
}

View File

@ -5,12 +5,12 @@
*/
package org.elasticsearch.xpack.sql.type;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import java.sql.JDBCType;
import java.util.LinkedHashMap;
import java.util.Map;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
public abstract class DataTypes {
public static final DataType NULL = new NullType();
@ -104,6 +104,9 @@ public abstract class DataTypes {
if (value instanceof Short) {
return SHORT;
}
if (value instanceof String) {
return KEYWORD;
}
throw new SqlIllegalArgumentException("No idea what's the DataType for %s", value.getClass());
}

View File

@ -74,7 +74,15 @@ public abstract class CollectionUtils {
return map;
}
public static <T> List<T> combine(Collection<? extends T> left, Collection<? extends T> right) {
@SuppressWarnings("unchecked")
public static <T> List<T> combine(List<? extends T> left, List<? extends T> right) {
if (right.isEmpty()) {
return (List<T>) left;
}
if (left.isEmpty()) {
return (List<T>) right;
}
List<T> list = new ArrayList<>(left.size() + right.size());
if (!left.isEmpty()) {
list.addAll(left);