From 85ee775390c54f96900d5857241e1b4ad7822318 Mon Sep 17 00:00:00 2001 From: Soumyava <93540295+somu-imply@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:49:21 -0700 Subject: [PATCH] Handling latest_by and earliest_by on numeric columns correctly (#15939) * Handling latest_by and earliest_by on numeric columns correctly * Adding test --- .../first/NumericFirstAggregator.java | 8 ++--- .../first/NumericFirstBufferAggregator.java | 8 ++--- .../last/NumericLastAggregator.java | 9 +++--- .../last/NumericLastBufferAggregator.java | 8 ++--- .../druid/sql/calcite/CalciteQueryTest.java | 30 +++++++++++++++++++ 5 files changed, 47 insertions(+), 16 deletions(-) diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstAggregator.java index b3092377b57..6b32996b4f2 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstAggregator.java @@ -62,10 +62,6 @@ public abstract class NumericFirstAggregator implements Aggregator @Override public void aggregate() { - if (timeSelector.isNull()) { - return; - } - if (needsFoldCheck) { final Object object = valueSelector.getObject(); if (object instanceof SerializablePair) { @@ -84,6 +80,10 @@ public abstract class NumericFirstAggregator implements Aggregator } } + if (timeSelector.isNull()) { + return; + } + long time = timeSelector.getLong(); if (time < firstTime) { firstTime = time; diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstBufferAggregator.java index 4531ee71bcd..f20456d3122 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/first/NumericFirstBufferAggregator.java @@ -97,10 +97,6 @@ public abstract class NumericFirstBufferAggregator implements BufferAggregator @Override public void aggregate(ByteBuffer buf, int position) { - if (timeSelector.isNull()) { - return; - } - long firstTime = buf.getLong(position); if (needsFoldCheck) { final Object object = valueSelector.getObject(); @@ -117,6 +113,10 @@ public abstract class NumericFirstBufferAggregator implements BufferAggregator } } + if (timeSelector.isNull()) { + return; + } + long time = timeSelector.getLong(); if (time < firstTime) { diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastAggregator.java index 159939450ee..50d4470fa54 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastAggregator.java @@ -61,10 +61,6 @@ public abstract class NumericLastAggregator implements Aggregator @Override public void aggregate() { - if (timeSelector.isNull()) { - return; - } - if (needsFoldCheck) { final Object object = valueSelector.getObject(); if (object instanceof SerializablePair) { @@ -83,6 +79,11 @@ public abstract class NumericLastAggregator implements Aggregator return; } } + + if (timeSelector.isNull()) { + return; + } + long time = timeSelector.getLong(); if (time >= lastTime) { lastTime = time; diff --git a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastBufferAggregator.java b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastBufferAggregator.java index 9de6f996887..2ba15a7929d 100644 --- a/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastBufferAggregator.java +++ b/processing/src/main/java/org/apache/druid/query/aggregation/last/NumericLastBufferAggregator.java @@ -100,10 +100,6 @@ public abstract class NumericLastBufferAggregator implements BufferAggregator @Override public void aggregate(ByteBuffer buf, int position) { - if (timeSelector.isNull()) { - return; - } - long lastTime = buf.getLong(position); if (needsFoldCheck) { final Object object = valueSelector.getObject(); @@ -121,6 +117,10 @@ public abstract class NumericLastBufferAggregator implements BufferAggregator } } + if (timeSelector.isNull()) { + return; + } + long time = timeSelector.getLong(); if (time >= lastTime) { diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java index c7de263eca4..8515fb8f9ac 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteQueryTest.java @@ -15404,4 +15404,34 @@ public class CalciteQueryTest extends BaseCalciteQueryTest ImmutableList.of(new Object[]{NullHandling.sqlCompatible() ? 4L : 0L}) ); } + + @Test + public void testLatestByAggregatorOnSecondaryTimestampGroupBy() + { + msqIncompatible(); + testQuery( + "SELECT __time, m1, LATEST_BY(m1, MILLIS_TO_TIMESTAMP(CAST(m2 AS NUMERIC))) from druid.numfoo GROUP BY 1,2", + ImmutableList.of( + new GroupByQuery.Builder() + .setDataSource(CalciteTests.DATASOURCE3) + .setInterval(querySegmentSpec(Filtration.eternity())) + .setGranularity(Granularities.ALL) + .setDimensions( + new DefaultDimensionSpec("__time", "_d0", ColumnType.LONG), + new DefaultDimensionSpec("m1", "_d1", ColumnType.FLOAT) + ) + .setAggregatorSpecs(aggregators(new FloatLastAggregatorFactory("a0", "m1", "m2"))) + .setContext(OUTER_LIMIT_CONTEXT) + .build() + ), + ImmutableList.of( + new Object[]{946684800000L, 1.0F, 1.0F}, + new Object[]{946771200000L, 2.0F, 2.0F}, + new Object[]{946857600000L, 3.0F, 3.0F}, + new Object[]{978307200000L, 4.0F, 4.0F}, + new Object[]{978393600000L, 5.0F, 5.0F}, + new Object[]{978480000000L, 6.0F, 6.0F} + ) + ); + } }