Previously, `COUNT(DISTINCT <literal>)` returned the same result as `COUNT(<literal>)`, which is incorrect: it should always return 1 if there is at least one matching row (or bucket, if there is a GROUP BY), and 0 otherwise. (cherry picked from commit 7f7d7562d43034907f432d39d0d66f490d78f4a8)
This commit is contained in:
parent
57c3a61535
commit
644ae49817
|
@ -178,6 +178,8 @@ SELECT gender g, languages l, COUNT(*) c FROM "test_emp" GROUP BY g, l ORDER BY
|
|||
aggCountDistinctWithAliasAndGroupBy
|
||||
SELECT COUNT(*) cnt, COUNT(DISTINCT first_name) as names, gender FROM test_emp GROUP BY gender ORDER BY gender;
|
||||
|
||||
localCount
|
||||
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo');
|
||||
localSum
|
||||
SELECT CAST(SUM(1) AS BIGINT);
|
||||
localSumWithAlias
|
||||
|
@ -185,7 +187,12 @@ SELECT CAST(SUM(1) AS BIGINT) AS s, CAST(SUM(1) AS BIGINT);
|
|||
localMax
|
||||
SELECT MAX(1);
|
||||
localAggregates
|
||||
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
|
||||
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
|
||||
|
||||
countOfLiteralsFromIndex
|
||||
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp;
|
||||
countOfLiteralsFromIndexWithGroupBy
|
||||
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp GROUP BY gender ORDER BY gender;
|
||||
aggregatesOfLiteralsFromIndex
|
||||
SELECT MAX(1), MIN(1), CAST(SUM(1) AS BIGINT), CAST(AVG(1) AS INTEGER), COUNT(1) FROM test_emp;
|
||||
aggregatesOfLiteralsFromIndex_WithNoMatchingFilter
|
||||
|
|
|
@ -66,7 +66,6 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
|
|||
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
|
||||
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
|
||||
|
@ -784,14 +783,17 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
|
|||
/**
|
||||
* Any numeric aggregates (avg, min, max, sum) acting on literals are converted to an iif(count(1)=0, null, literal*count(1)) for sum,
|
||||
* and to iif(count(1)=0,null,literal) for the other three.
|
||||
* Additionally count(DISTINCT literal) is converted to iif(count(1)=0, 0, 1).
|
||||
*/
|
||||
private static class ReplaceAggregatesWithLiterals extends OptimizerRule<LogicalPlan> {
|
||||
|
||||
@Override
|
||||
protected LogicalPlan rule(LogicalPlan p) {
|
||||
return p.transformExpressionsDown(e -> {
|
||||
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum) {
|
||||
NumericAggregate a = (NumericAggregate) e;
|
||||
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum ||
|
||||
(e instanceof Count && ((Count) e).distinct())) {
|
||||
|
||||
AggregateFunction a = (AggregateFunction) e;
|
||||
|
||||
if (a.field().foldable()) {
|
||||
Expression countOne = new Count(a.source(), new Literal(Source.EMPTY, 1, a.dataType()), false);
|
||||
|
@ -799,12 +801,16 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
|
|||
Expression argument = a.field();
|
||||
Literal foldedArgument = new Literal(argument.source(), argument.fold(), a.dataType());
|
||||
|
||||
Expression iifResult = Literal.NULL;
|
||||
Expression iifElseResult = foldedArgument;
|
||||
if (e instanceof Sum) {
|
||||
iifElseResult = new Mul(a.source(), countOne, foldedArgument);
|
||||
} else if (e instanceof Count) {
|
||||
iifResult = new Literal(Source.EMPTY, 0, e.dataType());
|
||||
iifElseResult = new Literal(Source.EMPTY, 1, e.dataType());
|
||||
}
|
||||
|
||||
return new Iif(a.source(), countEqZero, Literal.NULL, iifElseResult);
|
||||
return new Iif(a.source(), countEqZero, iifResult, iifElseResult);
|
||||
}
|
||||
}
|
||||
return e;
|
||||
|
|
|
@ -19,7 +19,12 @@ import java.time.Clock;
|
|||
import java.time.Duration;
|
||||
import java.time.ZonedDateTime;
|
||||
import java.time.ZoneId;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.StringJoiner;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
|
||||
import static org.elasticsearch.test.ESTestCase.randomBoolean;
|
||||
|
@ -97,5 +102,41 @@ public final class SqlTestUtils {
|
|||
}
|
||||
return new Literal(source, value, SqlDataTypes.fromJava(value));
|
||||
}
|
||||
|
||||
public static String randomOrderByAndLimit(int noOfSelectArgs, Random rnd) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (randomBoolean()) {
|
||||
sb.append(" ORDER BY ");
|
||||
|
||||
List<Integer> shuffledArgIndices = IntStream.range(1, noOfSelectArgs + 1).boxed().collect(Collectors.toList());
|
||||
Collections.shuffle(shuffledArgIndices, rnd);
|
||||
for (int i = 0; i < noOfSelectArgs; i++) {
|
||||
sb.append(shuffledArgIndices.get(i));
|
||||
switch (randomInt(2)) {
|
||||
case 0:
|
||||
sb.append(" DESC");
|
||||
break;
|
||||
case 1:
|
||||
sb.append(" ASC");
|
||||
break;
|
||||
}
|
||||
switch (randomInt(2)) {
|
||||
case 0:
|
||||
sb.append(" NULLS FIRST");
|
||||
break;
|
||||
case 1:
|
||||
sb.append(" NULLS LAST");
|
||||
break;
|
||||
}
|
||||
if (i < noOfSelectArgs - 1) {
|
||||
sb.append(", ");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (randomBoolean()) {
|
||||
sb.append(" LIMIT ").append(randomIntBetween(1, 100));
|
||||
}
|
||||
return sb.toString();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -175,31 +175,41 @@ public class QueryFolderTests extends ESTestCase {
|
|||
assertThat(ee.output().get(0).toString(), startsWith("E(){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggs() {
|
||||
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0)");
|
||||
public void testLocalExecWithCount() {
|
||||
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20)" + randomOrderByAndLimit(2));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(SingletonExecutable.class, le.executable().getClass());
|
||||
SingletonExecutable ee = (SingletonExecutable) le.executable();
|
||||
assertEquals(4, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
|
||||
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
|
||||
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
|
||||
assertEquals(2, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggsAndWhereFalseFilter() {
|
||||
PhysicalPlan p = plan("SELECT SUM(10) WHERE 2 > 3");
|
||||
public void testLocalExecWithCountAndWhereFalseFilter() {
|
||||
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 2" + randomOrderByAndLimit(2));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(EmptyExecutable.class, le.executable().getClass());
|
||||
EmptyExecutable ee = (EmptyExecutable) le.executable();
|
||||
assertEquals(1, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("SUM(10){r}#"));
|
||||
assertEquals(2, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggsAndWhereTrueFilter() {
|
||||
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0) WHERE 1 = 1");
|
||||
public void testLocalExecWithCountAndWhereTrueFilter() {
|
||||
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 1" + randomOrderByAndLimit(2));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(SingletonExecutable.class, le.executable().getClass());
|
||||
SingletonExecutable ee = (SingletonExecutable) le.executable();
|
||||
assertEquals(2, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggs() {
|
||||
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30)" + randomOrderByAndLimit(4));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(SingletonExecutable.class, le.executable().getClass());
|
||||
|
@ -207,30 +217,34 @@ public class QueryFolderTests extends ESTestCase {
|
|||
assertEquals(4, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
|
||||
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
|
||||
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
|
||||
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
|
||||
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderBy() {
|
||||
PhysicalPlan p = plan("SELECT MAX(23), SUM(1) WHERE 1 = 1 ORDER BY 1, 2 DESC");
|
||||
public void testLocalExecWithAggsAndWhereFalseFilter() {
|
||||
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 2 > 3" + randomOrderByAndLimit(4));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(EmptyExecutable.class, le.executable().getClass());
|
||||
EmptyExecutable ee = (EmptyExecutable) le.executable();
|
||||
assertEquals(4, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
|
||||
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
|
||||
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggsAndWhereTrueFilter() {
|
||||
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 1 = 1" + randomOrderByAndLimit(4));
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(SingletonExecutable.class, le.executable().getClass());
|
||||
SingletonExecutable ee = (SingletonExecutable) le.executable();
|
||||
assertEquals(2, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("MAX(23){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("SUM(1){r}#"));
|
||||
}
|
||||
|
||||
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderByAndLimit() {
|
||||
PhysicalPlan p = plan("SELECT AVG(10), SUM(2) WHERE 1 = 1 ORDER BY 1, 2 DESC LIMIT 5");
|
||||
assertEquals(LocalExec.class, p.getClass());
|
||||
LocalExec le = (LocalExec) p;
|
||||
assertEquals(SingletonExecutable.class, le.executable().getClass());
|
||||
SingletonExecutable ee = (SingletonExecutable) le.executable();
|
||||
assertEquals(2, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("AVG(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("SUM(2){r}#"));
|
||||
assertEquals(4, ee.output().size());
|
||||
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
|
||||
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
|
||||
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
|
||||
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
|
||||
}
|
||||
|
||||
public void testFoldingOfIsNull() {
|
||||
|
@ -489,4 +503,8 @@ public class QueryFolderTests extends ESTestCase {
|
|||
assertThat(a, containsString("\"terms\":{\"field\":\"keyword\""));
|
||||
assertThat(a, containsString("{\"avg\":{\"field\":\"int\"}"));
|
||||
}
|
||||
|
||||
private static String randomOrderByAndLimit(int noOfSelectArgs) {
|
||||
return SqlTestUtils.randomOrderByAndLimit(noOfSelectArgs, random());
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue