SQL: Fix behaviour of COUNT(DISTINCT <literal>) (#56869) (#56932)

Previously `COUNT(DISTINCT <literal>)` was returning the same result
as `COUNT(<literal>)` which is not correct as it should always return 1
if there is at least one matching row (bucket if there is a GROUP BY),
or 0 otherwise.

(cherry picked from commit 7f7d7562d43034907f432d39d0d66f490d78f4a8)
This commit is contained in:
Marios Trivyzas 2020-05-19 11:19:06 +02:00 committed by GitHub
parent 57c3a61535
commit 644ae49817
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 108 additions and 36 deletions

View File

@ -178,6 +178,8 @@ SELECT gender g, languages l, COUNT(*) c FROM "test_emp" GROUP BY g, l ORDER BY
aggCountDistinctWithAliasAndGroupBy
SELECT COUNT(*) cnt, COUNT(DISTINCT first_name) as names, gender FROM test_emp GROUP BY gender ORDER BY gender;
localCount
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo');
localSum
SELECT CAST(SUM(1) AS BIGINT);
localSumWithAlias
@ -185,7 +187,12 @@ SELECT CAST(SUM(1) AS BIGINT) AS s, CAST(SUM(1) AS BIGINT);
localMax
SELECT MAX(1);
localAggregates
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
SELECT CAST(SUM(1) AS BIGINT), CAST(SUM(123) AS BIGINT), MAX(1), MAX(32), MIN(3), MIN(55+2) AS mn, CAST(AVG(33/3) AS INTEGER) AS av, CAST(AVG(1) AS INTEGER);
countOfLiteralsFromIndex
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp;
countOfLiteralsFromIndexWithGroupBy
SELECT COUNT(1), COUNT(22), COUNT('foo'), COUNT(DISTINCT 1), COUNT(DISTINCT 22), COUNT(DISTINCT 'foo') FROM test_emp GROUP BY gender ORDER BY gender;
aggregatesOfLiteralsFromIndex
SELECT MAX(1), MIN(1), CAST(SUM(1) AS BIGINT), CAST(AVG(1) AS INTEGER), COUNT(1) FROM test_emp;
aggregatesOfLiteralsFromIndex_WithNoMatchingFilter

View File

@ -66,7 +66,6 @@ import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.NumericAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRanks;
@ -784,14 +783,17 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
/**
* Any numeric aggregates (avg, min, max, sum) acting on literals are converted to an iif(count(1)=0, null, literal*count(1)) for sum,
* and to iif(count(1)=0,null,literal) for the other three.
* Additionally count(DISTINCT literal) is converted to iif(count(1)=0, 0, 1).
*/
private static class ReplaceAggregatesWithLiterals extends OptimizerRule<LogicalPlan> {
@Override
protected LogicalPlan rule(LogicalPlan p) {
return p.transformExpressionsDown(e -> {
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum) {
NumericAggregate a = (NumericAggregate) e;
if (e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum ||
(e instanceof Count && ((Count) e).distinct())) {
AggregateFunction a = (AggregateFunction) e;
if (a.field().foldable()) {
Expression countOne = new Count(a.source(), new Literal(Source.EMPTY, 1, a.dataType()), false);
@ -799,12 +801,16 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Expression argument = a.field();
Literal foldedArgument = new Literal(argument.source(), argument.fold(), a.dataType());
Expression iifResult = Literal.NULL;
Expression iifElseResult = foldedArgument;
if (e instanceof Sum) {
iifElseResult = new Mul(a.source(), countOne, foldedArgument);
} else if (e instanceof Count) {
iifResult = new Literal(Source.EMPTY, 0, e.dataType());
iifElseResult = new Literal(Source.EMPTY, 1, e.dataType());
}
return new Iif(a.source(), countEqZero, Literal.NULL, iifElseResult);
return new Iif(a.source(), countEqZero, iifResult, iifElseResult);
}
}
return e;

View File

@ -19,7 +19,12 @@ import java.time.Clock;
import java.time.Duration;
import java.time.ZonedDateTime;
import java.time.ZoneId;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.StringJoiner;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength;
import static org.elasticsearch.test.ESTestCase.randomBoolean;
@ -97,5 +102,41 @@ public final class SqlTestUtils {
}
return new Literal(source, value, SqlDataTypes.fromJava(value));
}
public static String randomOrderByAndLimit(int noOfSelectArgs, Random rnd) {
StringBuilder sb = new StringBuilder();
if (randomBoolean()) {
sb.append(" ORDER BY ");
List<Integer> shuffledArgIndices = IntStream.range(1, noOfSelectArgs + 1).boxed().collect(Collectors.toList());
Collections.shuffle(shuffledArgIndices, rnd);
for (int i = 0; i < noOfSelectArgs; i++) {
sb.append(shuffledArgIndices.get(i));
switch (randomInt(2)) {
case 0:
sb.append(" DESC");
break;
case 1:
sb.append(" ASC");
break;
}
switch (randomInt(2)) {
case 0:
sb.append(" NULLS FIRST");
break;
case 1:
sb.append(" NULLS LAST");
break;
}
if (i < noOfSelectArgs - 1) {
sb.append(", ");
}
}
}
if (randomBoolean()) {
sb.append(" LIMIT ").append(randomIntBetween(1, 100));
}
return sb.toString();
}
}

View File

@ -175,31 +175,41 @@ public class QueryFolderTests extends ESTestCase {
assertThat(ee.output().get(0).toString(), startsWith("E(){r}#"));
}
public void testLocalExecWithAggs() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0)");
public void testLocalExecWithCount() {
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20)" + randomOrderByAndLimit(2));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}
public void testLocalExecWithAggsAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT SUM(10) WHERE 2 > 3");
public void testLocalExecWithCountAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 2" + randomOrderByAndLimit(2));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(EmptyExecutable.class, le.executable().getClass());
EmptyExecutable ee = (EmptyExecutable) le.executable();
assertEquals(1, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("SUM(10){r}#"));
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(1), AVG(0) WHERE 1 = 1");
public void testLocalExecWithCountAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT COUNT(10), COUNT(DISTINCT 20) WHERE 1 = 1" + randomOrderByAndLimit(2));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("COUNT(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("COUNT(DISTINCT 20){r}#"));
}
public void testLocalExecWithAggs() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30)" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
@ -207,30 +217,34 @@ public class QueryFolderTests extends ESTestCase {
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(1){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(0){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderBy() {
PhysicalPlan p = plan("SELECT MAX(23), SUM(1) WHERE 1 = 1 ORDER BY 1, 2 DESC");
public void testLocalExecWithAggsAndWhereFalseFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 2 > 3" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(EmptyExecutable.class, le.executable().getClass());
EmptyExecutable ee = (EmptyExecutable) le.executable();
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilter() {
PhysicalPlan p = plan("SELECT MIN(10), MAX(123), SUM(20), AVG(30) WHERE 1 = 1" + randomOrderByAndLimit(4));
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MAX(23){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(1){r}#"));
}
public void testLocalExecWithAggsAndWhereTrueFilterAndOrderByAndLimit() {
PhysicalPlan p = plan("SELECT AVG(10), SUM(2) WHERE 1 = 1 ORDER BY 1, 2 DESC LIMIT 5");
assertEquals(LocalExec.class, p.getClass());
LocalExec le = (LocalExec) p;
assertEquals(SingletonExecutable.class, le.executable().getClass());
SingletonExecutable ee = (SingletonExecutable) le.executable();
assertEquals(2, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("AVG(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("SUM(2){r}#"));
assertEquals(4, ee.output().size());
assertThat(ee.output().get(0).toString(), startsWith("MIN(10){r}#"));
assertThat(ee.output().get(1).toString(), startsWith("MAX(123){r}#"));
assertThat(ee.output().get(2).toString(), startsWith("SUM(20){r}#"));
assertThat(ee.output().get(3).toString(), startsWith("AVG(30){r}#"));
}
public void testFoldingOfIsNull() {
@ -489,4 +503,8 @@ public class QueryFolderTests extends ESTestCase {
assertThat(a, containsString("\"terms\":{\"field\":\"keyword\""));
assertThat(a, containsString("{\"avg\":{\"field\":\"int\"}"));
}
private static String randomOrderByAndLimit(int noOfSelectArgs) {
return SqlTestUtils.randomOrderByAndLimit(noOfSelectArgs, random());
}
}