SQL: implement SUM, MIN, MAX, AVG over literals (#56786) (#56850)

* Adds support for MIN, MAX, AVG, SUM aggregates acting on literals.
SELECT SUM(1) FROM index
and
SELECT SUM(1), AVG(2)
work both on indices and as local execution.

(cherry picked from commit efb72907c0391612c4a2b6256e327060b4167912)
This commit is contained in:
Andrei Stefan 2020-05-16 02:13:55 +03:00 committed by GitHub
parent 813609b47c
commit 4d47d63f55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 217 additions and 5 deletions

View File

@ -95,7 +95,6 @@ public abstract class ScalarFunction extends Function {
return new ScriptTemplate(processScript("{sql}.stWktToSql({})"), paramsBuilder().variable(fold.toString()).build(), dataType());
}
return new ScriptTemplate(processScript("{}"),
paramsBuilder().variable(fold).build(),
dataType());

View File

@ -175,6 +175,91 @@ F |F |1666196|1666196 |1666196
M |M |2671054|2671054 |2671054
;
sumLiteralWithTrueConditionAndHavingWithCount
SELECT SUM(1) AS c FROM test_emp WHERE 'a'='a' HAVING COUNT(1) > 0;
c:i
---------------
100
;
sumLiteralWithTwoConditionsAndGroupByField
SELECT SUM(10) AS s10, birth_date, SUM(1) AS c FROM test_emp WHERE (birth_date >= {ts '1959-01-01 00:00:00'}) AND (birth_date <= {ts '1959-12-31 23:59:59'}) GROUP BY 2;
s10:l | birth_date:ts | c:l
---------------+------------------------+---------------
10 |1959-01-27T00:00:00.000Z|1
10 |1959-04-07T00:00:00.000Z|1
20 |1959-07-23T00:00:00.000Z|2
10 |1959-08-10T00:00:00.000Z|1
10 |1959-08-19T00:00:00.000Z|1
10 |1959-10-01T00:00:00.000Z|1
10 |1959-12-03T00:00:00.000Z|1
10 |1959-12-25T00:00:00.000Z|1
;
sumLiteralWithGroupByAndTwoConditionsOnSum
SELECT first_name, SUM(1) AS c FROM test_emp GROUP BY 1 HAVING ((SUM(1) >= 0) AND (SUM(1) <= 577)) LIMIT 10;
first_name:s | c:l
---------------+---------------
null |10
Alejandro |1
Amabile |1
Anneke |1
Anoosh |1
Arumugam |1
Basil |1
Berhard |1
Berni |1
Bezalel |1
;
sumFieldWithSumLiteralAsCondition
SELECT first_name, last_name, SUM(salary) AS s, birth_date AS y, COUNT(1) FROM test_emp GROUP BY 1, 2, 4 HAVING ((SUM(1) >= 1) AND (SUM(1) <= 577)) AND ((SUM(salary) >= 35000) AND (SUM(salary) <= 45000));
first_name:s | last_name:s | s:i | y:ts | COUNT(1):l
---------------+---------------+---------------+------------------------+---------------
null |Brender |36051 |1959-10-01T00:00:00.000Z|1
null |Joslin |37716 |1959-01-27T00:00:00.000Z|1
null |Lortz |35222 |1960-07-20T00:00:00.000Z|1
null |Makrucki |37691 |1963-07-22T00:00:00.000Z|1
null |Swan |39878 |1962-12-29T00:00:00.000Z|1
Alejandro |McAlpine |44307 |1953-09-19T00:00:00.000Z|1
Amabile |Gomatam |38645 |1955-10-04T00:00:00.000Z|1
Basil |Tramer |37853 |null |1
Berhard |McFarlin |38376 |1954-10-01T00:00:00.000Z|1
Berni |Genin |37137 |1956-02-12T00:00:00.000Z|1
Chirstian |Koblick |36174 |1954-05-01T00:00:00.000Z|1
Domenick |Tempesti |39356 |1963-11-26T00:00:00.000Z|1
Hilari |Morton |37702 |1965-01-03T00:00:00.000Z|1
Hisao |Lipner |40612 |1958-01-21T00:00:00.000Z|1
Jayson |Mandell |43889 |1954-09-16T00:00:00.000Z|1
Jungsoon |Syrzycki |39638 |1954-02-25T00:00:00.000Z|1
Kendra |Hofting |44956 |1961-05-30T00:00:00.000Z|1
Kenroku |Malabarba |35742 |1962-11-07T00:00:00.000Z|1
Margareta |Bierman |41933 |1960-09-06T00:00:00.000Z|1
Mayuko |Warwick |40031 |1952-12-24T00:00:00.000Z|1
Mingsen |Casley |39728 |null |1
Mokhtar |Bernatsky |38992 |1955-08-28T00:00:00.000Z|1
Saniya |Kalloufi |43906 |1958-02-19T00:00:00.000Z|1
Sreekrishna |Servieres |44817 |1961-09-23T00:00:00.000Z|1
Sudharsan |Flasterstein |43602 |1963-03-21T00:00:00.000Z|1
Vishv |Zockler |39110 |1959-07-23T00:00:00.000Z|1
Weiyi |Meriste |37112 |null |1
Yinghua |Dredge |43026 |1958-05-21T00:00:00.000Z|1
Zvonko |Nyanchama |42716 |null |1
;
mirrorIifForNumericAggregate
SELECT IIF(COUNT(1)=0, NULL, 123)+5, AVG(123), MIN(123)+5, IIF(COUNT(1)=0, NULL, 30*COUNT(1)), SUM(30) FROM test_emp;
IIF(COUNT(1)=0, NULL, 123)+5:i| AVG(123):d | MIN(123)+5:i |IIF(COUNT(1)=0, NULL, 30*COUNT(1)):l| SUM(30):l
------------------------------+-----------------+-----------------+------------------------------------+---------------
128 |123 |128 |3000 |3000
;
aggByComplexCastedValue
SELECT CONVERT(CONCAT(LTRIM(CONVERT("emp_no", SQL_VARCHAR)), LTRIM(CONVERT("languages", SQL_VARCHAR))), SQL_BIGINT) AS "TEMP"
FROM "test_emp" GROUP BY "TEMP" ORDER BY "TEMP" LIMIT 20;

View File

@ -178,6 +178,22 @@ SELECT gender g, languages l, COUNT(*) c FROM "test_emp" GROUP BY g, l ORDER BY
aggCountDistinctWithAliasAndGroupBy
SELECT COUNT(*) cnt, COUNT(DISTINCT first_name) as names, gender FROM test_emp GROUP BY gender ORDER BY gender;
localSum
SELECT CAST(SUM(1) AS BIGINT);
localSumWithAlias
SELECT CAST(SUM(1) AS BIGINT) AS s, CAST(SUM(1) AS BIGINT);
localMax
SELECT MAX(1);
localAggregates
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
aggregatesOfLiteralsFromIndex
SELECT MAX(1), MIN(1), CAST(SUM(1) AS BIGINT), CAST(AVG(1) AS INTEGER), COUNT(1) FROM test_emp;
aggregatesOfLiteralsFromIndex_WithNoMatchingFilter
SELECT MAX(1), MIN(1), CAST(SUM(1) AS BIGINT), CAST(AVG(1) AS INTEGER), COUNT(1) FROM test_emp WHERE gender='123';
sumOfLiteralInHavingOnly
SELECT gender, COUNT(*) FROM test_emp GROUP BY gender HAVING SUM(10) > 200 ORDER BY gender;
sumLiteralAndSumFieldWithComplexHaving
SELECT gender, CAST(SUM("salary") AS BIGINT), CAST(SUM(1) AS BIGINT), CAST(SUM(10) AS BIGINT), COUNT(*) FROM test_emp GROUP BY gender HAVING ((SUM(1) >= 0) AND (SUM(1) <= 50) AND (SUM(salary) >= 250000) AND (SUM(salary) <= 5000000)) ORDER BY gender;
// Conditional COUNT

View File

@ -16,7 +16,7 @@ import java.util.List;
import static org.elasticsearch.xpack.ql.expression.TypeResolutions.isNumeric;
abstract class NumericAggregate extends AggregateFunction {
public abstract class NumericAggregate extends AggregateFunction {
NumericAggregate(Source source, Expression field, List<Expression> parameters) {
super(source, field, parameters);

View File

@ -20,6 +20,7 @@ import org.elasticsearch.xpack.ql.expression.ReferenceAttribute;
import org.elasticsearch.xpack.ql.expression.UnresolvedAttribute;
import org.elasticsearch.xpack.ql.expression.function.Function;
import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.ql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.ql.expression.function.aggregate.InnerAggregate;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNotNull;
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
@ -56,6 +57,7 @@ import org.elasticsearch.xpack.ql.type.DataTypes;
import org.elasticsearch.xpack.ql.util.Holder;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer.CleanAliases;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg;
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.First;
@ -64,17 +66,21 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentiles;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ArbitraryConditionalFunction;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Case;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Coalesce;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.IfConditional;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.Iif;
import org.elasticsearch.xpack.sql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.logical.LocalRelation;
import org.elasticsearch.xpack.sql.plan.logical.Pivot;
@ -115,7 +121,10 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
new RewritePivot());
Batch refs = new Batch("Replace References", Limiter.ONCE,
new ReplaceReferenceAttributeWithSource());
new ReplaceReferenceAttributeWithSource(),
new ReplaceAggregatesWithLiterals(),
new ReplaceCountInLocalRelation()
);
Batch operators = new Batch("Operator Optimization",
// combining
@ -772,6 +781,52 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
}
/**
* Any numeric aggregates (avg, min, max, sum) acting on literals are converted to an iif(count(1)=0, null, literal*count(1)) for sum,
* and to iif(count(1)=0,null,literal) for the other three.
*/
private static class ReplaceAggregatesWithLiterals extends OptimizerRule<LogicalPlan> {
@Override
protected LogicalPlan rule(LogicalPlan p) {
return p.transformExpressionsDown(e -> {
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum) {
NumericAggregate a = (NumericAggregate) e;
if (a.field().foldable()) {
Expression countOne = new Count(a.source(), new Literal(Source.EMPTY, 1, a.dataType()), false);
Equals countEqZero = new Equals(a.source(), countOne, new Literal(Source.EMPTY, 0, a.dataType()));
Expression argument = a.field();
Literal foldedArgument = new Literal(argument.source(), argument.fold(), a.dataType());
Expression iifElseResult = foldedArgument;
if (e instanceof Sum) {
iifElseResult = new Mul(a.source(), countOne, foldedArgument);
}
return new Iif(a.source(), countEqZero, Literal.NULL, iifElseResult);
}
}
return e;
});
}
}
/**
* A COUNT in a local relation will always be 1.
*/
private static class ReplaceCountInLocalRelation extends OptimizerRule<Aggregate> {
@Override
protected LogicalPlan rule(Aggregate a) {
boolean hasLocalRelation = a.anyMatch(LocalRelation.class::isInstance);
return hasLocalRelation ? a.transformExpressionsDown(c -> {
return c instanceof Count ? new Literal(c.source(), 1, c.dataType()) : c;
}) : a;
}
}
static class ReplaceAggsWithMatrixStats extends OptimizerBasicRule {
@Override
@ -1115,8 +1170,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
} else if (n.foldable()) {
values.add(n.fold());
}
else {
} else {
// not everything is foldable, bail-out early
return values;
}

View File

@ -175,6 +175,64 @@ public class QueryFolderTests extends ESTestCase {
assertThat(ee.output().get(0).toString(), startsWith("E(){r}#"));
}
public void testLocalExecWithAggs() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0)");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
}
public void testLocalExecWithAggsAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT SUM(10) WHERE 2 > 3");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(EmptyExecutable.class, le.executable().getClass());
EmptyExecutable ee = (EmptyExecutable) le.executable();
assertEquals(1, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("SUM(10){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0) WHERE 1 = 1");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderBy() {
PhysicalPlan p = plan("SELECT MAX(23), SUM(1) WHERE 1 = 1 ORDER BY 1, 2 DESC");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MAX(23){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(1){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderByAndLimit() {
PhysicalPlan p = plan("SELECT AVG(10), SUM(2) WHERE 1 = 1 ORDER BY 1, 2 DESC LIMIT 5");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("AVG(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(2){r}#"));
}
public void testFoldingOfIsNull() {
PhysicalPlan p = plan("SELECT keyword FROM test WHERE (keyword IS NOT NULL) IS NULL");
assertEquals(LocalExec.class, p.getClass());