Add BIG_SUM SQL function (#13102)

This adds a SQL function, BIG_SUM, that uses
CompressedBigDecimal to compute a sum. Other miscellaneous changes:

1. handle NumberFormatExceptions when parsing a string (defaults to
   0; configurable in the agg factory to be strict and throw on error)
2. format pom file (whitespace) and add dependencies
3. rename scaleUp -> scale and always require scale as a parameter
Sam Rash 2022-09-26 18:02:25 -07:00 committed by GitHub
parent 1f1fced6d4
commit 28b9edc2a8
13 changed files with 752 additions and 176 deletions
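For orientation, these are the call shapes the new function accepts, shown as a hedged sketch in the style of the new test file; the table and column names are hypothetical, and the quoted defaults come from DEFAULT_SIZE and DEFAULT_SCALE in the aggregator factory below:

// BIG_SUM(column [, size [, scale [, strictNumberParsing]]])
String q1 = "SELECT BIG_SUM(revenue) FROM sales";             // defaults: size 6, scale 9, lenient parsing
String q2 = "SELECT BIG_SUM(revenue, 9) FROM sales";          // explicit size
String q3 = "SELECT BIG_SUM(revenue, 9, 3) FROM sales";       // explicit size and scale
String q4 = "SELECT BIG_SUM(revenue, 9, 3, true) FROM sales"; // strict: throw on unparseable strings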

extensions-contrib/compressed-bigdecimal/pom.xml

@ -18,113 +18,139 @@
~ under the License.
-->
<project
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"
xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.apache.druid</groupId>
<artifactId>druid</artifactId>
<version>25.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<groupId>org.apache.druid.extensions.contrib</groupId>
<artifactId>druid-compressed-bigdecimal</artifactId>
<name>druid-compressed-bigdecimal</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-core</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-processing</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-server</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.calcite</groupId>
<artifactId>calcite-core</artifactId>
<version>1.21.0</version>
<scope>provided</scope>
</dependency>
<!-- Tests -->
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-core</artifactId>
<version>${project.parent.version}</version>
<classifier>tests</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-processing</artifactId>
<version>${project.parent.version}</version>
<type>test-jar</type>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-sql</artifactId>
<version>${project.parent.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-server</artifactId>
<version>${project.parent.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.hamcrest</groupId>
<artifactId>java-hamcrest</artifactId>
<version>2.0.0.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</dependency>
<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
<version>2.0.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.google.inject</groupId>
<artifactId>guice</artifactId>
<version>4.1.0</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>2.10.5</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>16.0.1</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>2.10.2</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>2.10.2</version>
</dependency>
</dependencies>
</project>

CompressedBigDecimalAggregator.java

@ -29,22 +29,26 @@ public class CompressedBigDecimalAggregator implements Aggregator
{
private final ColumnValueSelector<CompressedBigDecimal> selector;
private final boolean strictNumberParsing;
private final CompressedBigDecimal sum;
/**
* Constructor.
*
* @param size the size to allocate
* @param scale the scale
* @param selector that has the metric value
* @param strictNumberParsing true => NumberFormatExceptions thrown; false => NumberFormatException returns 0
*/
public CompressedBigDecimalAggregator(
int size,
int scale,
ColumnValueSelector<CompressedBigDecimal> selector,
boolean strictNumberParsing
)
{
this.selector = selector;
this.strictNumberParsing = strictNumberParsing;
this.sum = ArrayCompressedBigDecimal.allocate(size, scale);
}
@ -54,10 +58,11 @@ public class CompressedBigDecimalAggregator implements Aggregator
@Override
public void aggregate()
{
CompressedBigDecimal selectedObject = Utils.objToCompressedBigDecimal(selector.getObject(), strictNumberParsing);
if (selectedObject != null) {
if (selectedObject.getScale() != sum.getScale()) {
selectedObject = Utils.scale(selectedObject, sum.getScale());
}
sum.accumulate(selectedObject);
}
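In plain BigDecimal terms, aggregate() above does roughly the following; this is a simplified model of the flow, not the real CompressedBigDecimal arithmetic:

import java.math.BigDecimal;
import java.math.RoundingMode;

class AggregateModel
{
  public static void main(String[] args)
  {
    BigDecimal sum = BigDecimal.ZERO.setScale(9);              // scale is fixed at construction time
    BigDecimal addend = new BigDecimal("12.3456789012");       // incoming value at a different scale
    if (addend.scale() != sum.scale()) {
      addend = addend.setScale(sum.scale(), RoundingMode.UP);  // the Utils.scale step
    }
    sum = sum.add(addend);    // CompressedBigDecimal accumulates in place instead
    System.out.println(sum);  // 12.345678902
  }
}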

CompressedBigDecimalAggregatorFactory.java

@ -21,6 +21,7 @@ package org.apache.druid.compressedbigdecimal;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Objects;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregateCombiner;
import org.apache.druid.query.aggregation.Aggregator;
@ -48,7 +49,9 @@ public class CompressedBigDecimalAggregatorFactory
{
public static final int DEFAULT_SCALE = 9;
public static final int DEFAULT_SIZE = 6;
public static final boolean DEFAULT_STRICT_NUMBER_PARSING = false;
private static final byte CACHE_TYPE_ID = 0x37;
public static final Comparator<CompressedBigDecimal> COMPARATOR = CompressedBigDecimal::compareTo;
@ -57,27 +60,32 @@ public class CompressedBigDecimalAggregatorFactory
private final String fieldName;
private final int size;
private final int scale;
private final boolean strictNumberParsing;
/**
* Constructor.
*
* @param name metric field name
* @param fieldName metric field name
* @param size size of the int array used for calculations
* @param scale scale of the number
* @param strictNumberParsing if true, failure to parse strings to numbers throws an exception. otherwise 0 is
* returned
*/
@JsonCreator
public CompressedBigDecimalAggregatorFactory(
@JsonProperty("name") String name,
@JsonProperty("fieldName") String fieldName,
@JsonProperty(value = "size", required = false) Integer size,
@JsonProperty(value = "scale", required = false) Integer scale
@JsonProperty(value = "scale", required = false) Integer scale,
@JsonProperty(value = "strictNumberParsing", required = false) Boolean strictNumberParsing
)
{
this.name = name;
this.fieldName = fieldName;
this.size = size == null ? DEFAULT_SIZE : size;
this.scale = scale == null ? DEFAULT_SCALE : scale;
this.strictNumberParsing = strictNumberParsing == null ? DEFAULT_STRICT_NUMBER_PARSING : strictNumberParsing;
}
@SuppressWarnings("unchecked")
@ -88,17 +96,21 @@ public class CompressedBigDecimalAggregatorFactory
}
@Override
protected Aggregator factorize(
ColumnSelectorFactory metricFactory,
@Nonnull ColumnValueSelector<CompressedBigDecimal> selector
)
{
return new CompressedBigDecimalAggregator(size, scale, selector, strictNumberParsing);
}
@Override
protected BufferAggregator factorizeBuffered(
ColumnSelectorFactory metricFactory,
@Nonnull ColumnValueSelector<CompressedBigDecimal> selector
)
{
return new CompressedBigDecimalBufferAggregator(size, scale, selector, strictNumberParsing);
}
/* (non-Javadoc)
@ -148,7 +160,7 @@ public class CompressedBigDecimalAggregatorFactory
@Override
public AggregatorFactory getCombiningFactory()
{
return new CompressedBigDecimalAggregatorFactory(name, name, size, scale, strictNumberParsing);
}
@Override
@ -167,7 +179,8 @@ public class CompressedBigDecimalAggregatorFactory
fieldName,
fieldName,
size,
scale,
strictNumberParsing
));
}
@ -270,6 +283,12 @@ public class CompressedBigDecimalAggregatorFactory
return size;
}
@JsonProperty
public boolean getStrictNumberParsing()
{
return strictNumberParsing;
}
/* (non-Javadoc)
* @see org.apache.druid.query.aggregation.AggregatorFactory#getMaxIntermediateSize()
*/
@ -279,16 +298,40 @@ public class CompressedBigDecimalAggregatorFactory
return Integer.BYTES * size;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
CompressedBigDecimalAggregatorFactory that = (CompressedBigDecimalAggregatorFactory) o;
return size == that.size
&& scale == that.scale
&& Objects.equal(name, that.name)
&& Objects.equal(fieldName, that.fieldName)
&& Objects.equal(strictNumberParsing, that.strictNumberParsing);
}
@Override
public int hashCode()
{
return Objects.hashCode(name, fieldName, size, scale, strictNumberParsing);
}
@Override
public String toString()
{
return "CompressedBigDecimalAggregatorFactory{" +
"name='" + getName() + '\'' +
", type='" + getComplexTypeName() + '\'' +
", fieldName='" + getFieldName() + '\'' +
", requiredFields='" + requiredFields() + '\'' +
", size='" + getSize() + '\'' +
", scale='" + getScale() + '\'' +
'}';
return "CompressedBigDecimalSumAggregatorFactory{" +
"name='" + getName() + '\'' +
", type='" + getComplexTypeName() + '\'' +
", fieldName='" + getFieldName() + '\'' +
", requiredFields='" + requiredFields() + '\'' +
", size='" + getSize() + '\'' +
", scale='" + getScale() + '\'' +
", strictNumberParsing='" + getStrictNumberParsing() + '\'' +
'}';
}
}
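Direct construction mirrors the JSON properties one for one; a hedged sketch (the metric and column names are hypothetical, and per the constructor above a null size/scale/strictNumberParsing falls back to the defaults):

CompressedBigDecimalAggregatorFactory factory = new CompressedBigDecimalAggregatorFactory(
    "revenueSum",  // name: output metric (hypothetical)
    "revenue",     // fieldName: input column (hypothetical)
    null,          // size -> DEFAULT_SIZE (6)
    null,          // scale -> DEFAULT_SCALE (9)
    null           // strictNumberParsing -> DEFAULT_STRICT_NUMBER_PARSING (false)
);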

CompressedBigDecimalBufferAggregator.java

@ -35,23 +35,27 @@ public class CompressedBigDecimalBufferAggregator implements BufferAggregator
private final ColumnValueSelector<CompressedBigDecimal> selector;
private final int size;
private final int scale;
private boolean strictNumberParsing;
/**
* Constructor.
*
* @param size the size to allocate
* @param scale the scale
* @param selector a ColumnSelector to retrieve incoming values
* @param strictNumberParsing true => NumberFormatExceptions thrown; false => NumberFormatException returns 0
*/
public CompressedBigDecimalBufferAggregator(
int size,
int scale,
ColumnValueSelector<CompressedBigDecimal> selector,
boolean strictNumberParsing
)
{
this.selector = selector;
this.size = size;
this.scale = scale;
this.strictNumberParsing = strictNumberParsing;
}
/* (non-Javadoc)
@ -71,7 +75,7 @@ public class CompressedBigDecimalBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
CompressedBigDecimal addend = Utils.objToCompressedBigDecimal(selector.getObject(), strictNumberParsing);
if (addend != null) {
Utils.accumulate(buf, position, size, scale, addend);
}
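The buffered variant owns no state; callers hand it a ByteBuffer slot sized per getMaxIntermediateSize() in the factory (Integer.BYTES * size). A minimal sketch with assumed values:

int size = 6;                                                // ints per value (DEFAULT_SIZE)
ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES * size);  // one aggregation slot
// aggregator.init(buf, 0);       // prepare the slot
// aggregator.aggregate(buf, 0);  // once per row; accumulates via Utils.accumulate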

CompressedBigDecimalModule.java

@ -26,6 +26,8 @@ import com.google.common.collect.ImmutableList;
import com.google.inject.Binder;
import org.apache.druid.initialization.DruidModule;
import org.apache.druid.segment.serde.ComplexMetrics;
import org.apache.druid.sql.guice.SqlBindings;
import java.util.List;
/**
@ -37,6 +39,12 @@ public class CompressedBigDecimalModule implements DruidModule
@Override
public void configure(Binder binder)
{
registerSerde();
SqlBindings.addAggregator(binder, CompressedBigDecimalSqlAggregator.class);
}
public static void registerSerde()
{
if (ComplexMetrics.getSerdeForType(COMPRESSED_BIG_DECIMAL) == null) {
ComplexMetrics.registerSerde(COMPRESSED_BIG_DECIMAL, new CompressedBigDecimalMetricSerde());
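Note for test code: configure() now needs a live Binder (it registers the SQL aggregator), so tests that previously called configure(null) switch to the static helper, as the test changes further down show:

// In test setup: register only the complex-type serde; no Guice binder required.
CompressedBigDecimalModule.registerSerde();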

CompressedBigDecimalSqlAggregator.java

@ -0,0 +1,206 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.compressedbigdecimal;
import com.google.common.collect.ImmutableList;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.util.Optionality;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.aggregation.Aggregation;
import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.Expressions;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
public class CompressedBigDecimalSqlAggregator implements SqlAggregator
{
private static final SqlAggFunction FUNCTION_INSTANCE = new CompressedBigDecimalSqlAggFunction();
private static final String NAME = "BIG_SUM";
@Override
public SqlAggFunction calciteFunction()
{
return FUNCTION_INSTANCE;
}
@Nullable
@Override
public Aggregation toDruidAggregation(
PlannerContext plannerContext,
RowSignature rowSignature,
VirtualColumnRegistry virtualColumnRegistry,
RexBuilder rexBuilder,
String name,
AggregateCall aggregateCall,
Project project,
List<Aggregation> existingAggregations,
boolean finalizeAggregations
)
{
if (aggregateCall.getArgList().size() < 1) {
return null;
}
// fetch sum column expression
DruidExpression sumColumn = Expressions.toDruidExpression(
plannerContext,
rowSignature,
Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(0)
)
);
if (sumColumn == null) {
return null;
}
String sumColumnName;
if (sumColumn.isDirectColumnAccess()) {
sumColumnName = sumColumn.getDirectColumn();
} else {
sumColumnName =
virtualColumnRegistry.getOrCreateVirtualColumnForExpression(sumColumn, ColumnType.UNKNOWN_COMPLEX);
}
// check if size is provided
Integer size = null;
if (aggregateCall.getArgList().size() >= 2) {
RexNode sizeArg = Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(1)
);
size = ((Number) RexLiteral.value(sizeArg)).intValue();
}
// check if scale is provided
Integer scale = null;
if (aggregateCall.getArgList().size() >= 3) {
RexNode scaleArg = Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(2)
);
scale = ((Number) RexLiteral.value(scaleArg)).intValue();
}
Boolean useStrictNumberParsing = null;
if (aggregateCall.getArgList().size() >= 4) {
RexNode useStrictNumberParsingArg = Expressions.fromFieldAccess(
rowSignature,
project,
aggregateCall.getArgList().get(3)
);
useStrictNumberParsing = RexLiteral.booleanValue(useStrictNumberParsingArg);
}
// create the factory
AggregatorFactory aggregatorFactory = new CompressedBigDecimalAggregatorFactory(
StringUtils.format("%s:agg", name),
sumColumnName,
size,
scale,
useStrictNumberParsing
);
return Aggregation.create(ImmutableList.of(aggregatorFactory), null);
}
private static class CompressedBigDecimalSqlAggFunction extends SqlAggFunction
{
private static final String SIGNATURE2 = "'" + NAME + "'(column, size)";
private static final String SIGNATURE3 = "'" + NAME + "'(column, size, scale)";
private static final String SIGNATURE4 = "'" + NAME + "'(column, size, scale, strictNumberParsing)";
CompressedBigDecimalSqlAggFunction()
{
super(
NAME,
null,
SqlKind.OTHER_FUNCTION,
ReturnTypes.explicit(SqlTypeName.OTHER),
null,
OperandTypes.or(
// first signature is the column only, BIG_SUM(column)
OperandTypes.and(OperandTypes.ANY, OperandTypes.family(SqlTypeFamily.ANY)),
OperandTypes.and(
OperandTypes.sequence(SIGNATURE2, OperandTypes.ANY, OperandTypes.POSITIVE_INTEGER_LITERAL),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(
SIGNATURE3,
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL
),
OperandTypes.family(SqlTypeFamily.ANY, SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.EXACT_NUMERIC)
),
OperandTypes.and(
OperandTypes.sequence(
SIGNATURE4,
OperandTypes.ANY,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.POSITIVE_INTEGER_LITERAL,
OperandTypes.BOOLEAN
),
OperandTypes.family(
SqlTypeFamily.ANY,
SqlTypeFamily.EXACT_NUMERIC,
SqlTypeFamily.EXACT_NUMERIC,
SqlTypeFamily.BOOLEAN
)
)
),
SqlFunctionCategory.USER_DEFINED_FUNCTION,
false,
false,
Optionality.IGNORED
);
}
}
}
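End to end, toDruidAggregation() above maps a BIG_SUM call onto the aggregator factory; roughly this, as a hedged sketch ('revenue' is hypothetical and null means 'use the default'):

// SELECT BIG_SUM(revenue, 9, 3, true) ... plans to approximately:
AggregatorFactory factory = new CompressedBigDecimalAggregatorFactory(
    "a0:agg",   // StringUtils.format("%s:agg", name)
    "revenue",  // direct column, or a generated virtual column for expressions
    9,          // size literal (2nd arg), null when omitted
    3,          // scale literal (3rd arg), null when omitted
    true        // strictNumberParsing literal (4th arg), null when omitted
);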

Utils.java

@ -23,6 +23,7 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.data.IndexedInts;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.function.ToIntBiFunction;
@ -103,7 +104,7 @@ public class Utils
}
BufferAccessor accessor = BufferAccessor.prepare(pos);
if (rhs.getScale() != lhsScale) {
rhs = Utils.scale(rhs, lhsScale);
}
CompressedBigDecimal.internalAdd(
lhsSize,
@ -116,27 +117,17 @@ public class Utils
);
}
public static CompressedBigDecimal scale(CompressedBigDecimal val, int scale)
{
return new ArrayCompressedBigDecimal(val.toBigDecimal().setScale(scale, RoundingMode.UP));
}
public static CompressedBigDecimal objToCompressedBigDecimal(Object obj)
{
return objToCompressedBigDecimal(obj, false);
}
public static CompressedBigDecimal objToCompressedBigDecimal(Object obj, boolean strictNumberParse)
{
CompressedBigDecimal result;
if (obj == null) {
@ -152,7 +143,16 @@ public class Utils
} else if (obj instanceof Float) {
result = new ArrayCompressedBigDecimal(BigDecimal.valueOf((Float) obj));
} else if (obj instanceof String) {
try {
result = new ArrayCompressedBigDecimal(new BigDecimal((String) obj));
}
catch (NumberFormatException e) {
if (strictNumberParse) {
throw e;
} else {
result = new ArrayCompressedBigDecimal(0L, 0);
}
}
} else if (obj instanceof CompressedBigDecimal) {
result = (CompressedBigDecimal) obj;
} else {
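Since scale() uses RoundingMode.UP, truncation rounds away from zero only when nonzero digits are dropped; that is why the testScaleDown case below expects an exact trim. A runnable check:

import java.math.BigDecimal;
import java.math.RoundingMode;

class ScaleDemo
{
  public static void main(String[] args)
  {
    // trailing zero dropped: no rounding happens
    System.out.println(new BigDecimal("1.1234567890").setScale(9, RoundingMode.UP)); // 1.123456789
    // nonzero digit dropped: rounds away from zero
    System.out.println(new BigDecimal("1.1234567891").setScale(9, RoundingMode.UP)); // 1.123456790
    // scaling up only pads zeros
    System.out.println(new BigDecimal("1.5").setScale(3, RoundingMode.UP));          // 1.500
  }
}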

AggregatorCombinerFactoryTest.java

@ -56,9 +56,15 @@ public class AggregatorCombinerFactoryTest
@Test
public void testCompressedBigDecimalAggregatorFactory()
{
CompressedBigDecimalAggregatorFactory cf = new CompressedBigDecimalAggregatorFactory(
"name",
"fieldName",
9,
0,
false
);
Assert.assertEquals(
"CompressedBigDecimalAggregatorFactory{name='name', type='compressedBigDecimal', fieldName='fieldName', requiredFields='[fieldName]', size='9', scale='0'}",
"CompressedBigDecimalSumAggregatorFactory{name='name', type='compressedBigDecimal', fieldName='fieldName', requiredFields='[fieldName]', size='9', scale='0', strictNumberParsing='false'}",
cf.toString()
);
Assert.assertNotNull(cf.getCacheKey());
@ -67,7 +73,7 @@ public class AggregatorCombinerFactoryTest
Assert.assertEquals("5", cf.deserialize(5d).toString());
Assert.assertEquals("5", cf.deserialize("5").toString());
Assert.assertEquals(
"[CompressedBigDecimalAggregatorFactory{name='fieldName', type='compressedBigDecimal', fieldName='fieldName', requiredFields='[fieldName]', size='9', scale='0'}]",
"[CompressedBigDecimalSumAggregatorFactory{name='fieldName', type='compressedBigDecimal', fieldName='fieldName', requiredFields='[fieldName]', size='9', scale='0', strictNumberParsing='false'}]",
Arrays.toString(cf.getRequiredColumns().toArray())
);
Assert.assertEquals("0", cf.combine(null, null).toString());
@ -88,7 +94,13 @@ public class AggregatorCombinerFactoryTest
@Test(expected = RuntimeException.class)
public void testCompressedBigDecimalAggregatorFactoryDeserialize()
{
CompressedBigDecimalAggregatorFactory cf = new CompressedBigDecimalAggregatorFactory(
"name",
"fieldName",
9,
0,
false
);
cf.deserialize(5);
}
@ -100,7 +112,7 @@ public class AggregatorCombinerFactoryTest
{
ColumnValueSelector<CompressedBigDecimal> cs = EasyMock.createMock(ColumnValueSelector.class);
ByteBuffer bbuf = ByteBuffer.allocate(10);
CompressedBigDecimalBufferAggregator ca = new CompressedBigDecimalBufferAggregator(4, 0, cs, false);
ca.getFloat(bbuf, 0);
}
@ -112,7 +124,7 @@ public class AggregatorCombinerFactoryTest
{
ColumnValueSelector<CompressedBigDecimal> cs = EasyMock.createMock(ColumnValueSelector.class);
ByteBuffer bbuf = ByteBuffer.allocate(10);
CompressedBigDecimalBufferAggregator ca = new CompressedBigDecimalBufferAggregator(4, 0, cs, false);
ca.getLong(bbuf, 0);
}
@ -174,7 +186,7 @@ public class AggregatorCombinerFactoryTest
public void testCompressedBigDecimalAggregatorGetFloat()
{
ColumnValueSelector cv = EasyMock.createMock(ColumnValueSelector.class);
CompressedBigDecimalAggregator cc = new CompressedBigDecimalAggregator(2, 0, cv, false);
cc.getFloat();
}
@ -185,7 +197,7 @@ public class AggregatorCombinerFactoryTest
public void testCompressedBigDecimalAggregatorGetLong()
{
ColumnValueSelector cv = EasyMock.createMock(ColumnValueSelector.class);
CompressedBigDecimalAggregator cc = new CompressedBigDecimalAggregator(2, 0, cv, false);
cc.getLong();
}
}

ArrayCompressedBigDecimalTest.java

@ -166,10 +166,10 @@ public class ArrayCompressedBigDecimalTest
CompressedBigDecimalAggregatorFactory.DEFAULT_SCALE
);
d1.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(Integer.MAX_VALUE)), d1.getScale()));
d2.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(-2L * Integer.MAX_VALUE)), d2.getScale()));
d3.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(8L * Integer.MAX_VALUE)), d3.getScale()));
d4.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(8L * Integer.MAX_VALUE)), d4.getScale()));
Assert.assertEquals(-1, d2.compareTo(d1, true));
Assert.assertEquals(1, d1.compareTo(d2, true));
@ -227,10 +227,10 @@ public class ArrayCompressedBigDecimalTest
CompressedBigDecimalAggregatorFactory.DEFAULT_SCALE
);
d1.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(Integer.MAX_VALUE)), d1.getScale()));
d2.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(-8L * Integer.MAX_VALUE)), d2.getScale()));
d3.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(8L * Integer.MAX_VALUE)), d3.getScale()));
d4.accumulate(Utils.scale(new ArrayCompressedBigDecimal(new BigDecimal(8L * Integer.MAX_VALUE)), d4.getScale()));
Assert.assertEquals(-1, d2.compareTo(d1, true));
Assert.assertEquals(1, d1.compareTo(d2, true));
@ -560,6 +560,14 @@ public class ArrayCompressedBigDecimalTest
c1.compare(bd, add);
}
@Test
public void testScaleDown()
{
CompressedBigDecimal bd = new ArrayCompressedBigDecimal(new BigDecimal("1.1234567890"));
CompressedBigDecimal scaled = Utils.scale(bd, 9);
Assert.assertEquals("1.123456789", scaled.toString());
}
/**
* Test method for {@link CompressedBigDecimalObjectStrategy}
*/

CompressedBigDecimalAggregatorGroupByTest.java

@ -72,7 +72,7 @@ public class CompressedBigDecimalAggregatorGroupByTest
public CompressedBigDecimalAggregatorGroupByTest(GroupByQueryConfig config)
{
CompressedBigDecimalModule module = new CompressedBigDecimalModule();
CompressedBigDecimalModule.registerSerde();
helper = AggregationTestHelper.createGroupByQueryAggregationTestHelper(
module.getJacksonModules(), config, tempFolder);
}

CompressedBigDecimalAggregatorTimeseriesTest.java

@ -67,7 +67,7 @@ public class CompressedBigDecimalAggregatorTimeseriesTest
public CompressedBigDecimalAggregatorTimeseriesTest()
{
CompressedBigDecimalModule module = new CompressedBigDecimalModule();
CompressedBigDecimalModule.registerSerde();
helper = AggregationTestHelper.createTimeseriesQueryAggregationTestHelper(
module.getJacksonModules(), tempFolder);
}

CompressedBigDecimalSqlAggregatorTest.java

@ -0,0 +1,264 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.compressedbigdecimal;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import org.apache.druid.data.input.InputRow;
import org.apache.druid.data.input.impl.DimensionsSpec;
import org.apache.druid.data.input.impl.InputRowParser;
import org.apache.druid.data.input.impl.MapInputRowParser;
import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
import org.apache.druid.data.input.impl.TimestampSpec;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.query.Druids;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.IndexBuilder;
import org.apache.druid.segment.QueryableIndex;
import org.apache.druid.segment.TestHelper;
import org.apache.druid.segment.incremental.IncrementalIndexSchema;
import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
import org.apache.druid.timeline.DataSegment;
import org.apache.druid.timeline.partition.LinearShardSpec;
import org.junit.Test;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class CompressedBigDecimalSqlAggregatorTest extends BaseCalciteQueryTest
{
private static final InputRowParser<Map<String, Object>> PARSER = new MapInputRowParser(
new TimeAndDimsParseSpec(
new TimestampSpec(CalciteTests.TIMESTAMP_COLUMN, "iso", null),
new DimensionsSpec(
DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3", "m2"))
)
)
);
private static final List<InputRow> ROWS1 =
CalciteTests.RAW_ROWS1.stream().map(m -> CalciteTests.createRow(m, PARSER)).collect(Collectors.toList());
@Override
public Iterable<? extends Module> getJacksonModules()
{
CompressedBigDecimalModule bigDecimalModule = new CompressedBigDecimalModule();
return Iterables.concat(super.getJacksonModules(), bigDecimalModule.getJacksonModules());
}
@Override
public SpecificSegmentsQuerySegmentWalker createQuerySegmentWalker() throws IOException
{
CompressedBigDecimalModule bigDecimalModule = new CompressedBigDecimalModule();
for (Module mod : bigDecimalModule.getJacksonModules()) {
CalciteTests.getJsonMapper().registerModule(mod);
TestHelper.JSON_MAPPER.registerModule(mod);
}
QueryableIndex index =
IndexBuilder.create()
.tmpDir(temporaryFolder.newFolder())
.segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
.schema(
new IncrementalIndexSchema.Builder()
.withMetrics(
new CountAggregatorFactory("cnt"),
new DoubleSumAggregatorFactory("m1", "m1")
)
.withRollup(false)
.build()
)
.rows(ROWS1)
.buildMMappedIndex();
walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
DataSegment.builder()
.dataSource(CalciteTests.DATASOURCE1)
.interval(index.getDataInterval())
.version("1")
.shardSpec(new LinearShardSpec(0))
.size(0)
.build(),
index
);
return walker;
}
@Override
public DruidOperatorTable createOperatorTable()
{
return new DruidOperatorTable(ImmutableSet.of(new CompressedBigDecimalSqlAggregator()), ImmutableSet.of());
}
@Override
public ObjectMapper createQueryJsonMapper()
{
ObjectMapper objectMapper = super.createQueryJsonMapper();
objectMapper.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true);
objectMapper.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
return objectMapper;
}
@Test
public void testCompressedBigDecimalAggWithNumberParse1()
{
cannotVectorize();
testQuery(
"SELECT big_sum(m1, 9, 9), big_sum(m2, 9, 9), big_sum(dim1, 9, 9, false) FROM foo",
Collections.singletonList(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
new CompressedBigDecimalAggregatorFactory("a0:agg", "m1", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a1:agg", "m2", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a2:agg", "dim1", 9, 9, false)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{
"21.000000000",
"21.000000000",
"13.100000000",
})
);
}
@Test(expected = NumberFormatException.class)
public void testCompressedBigDecimalAggWithNumberParse2()
{
cannotVectorize();
testQuery(
"SELECT big_sum(dim1, 9, 9, true) FROM foo",
Collections.singletonList(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
new CompressedBigDecimalAggregatorFactory("a0:agg", "dim1", 9, 9, true)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{"13.100000000"})
);
}
@Test
public void testCompressedBigDecimalAggDefaultNumberParse()
{
cannotVectorize();
testQuery(
"SELECT big_sum(m1, 9, 9), big_sum(m2, 9, 9), big_sum(dim1, 9, 9) FROM foo",
Collections.singletonList(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
new CompressedBigDecimalAggregatorFactory("a0:agg", "m1", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a1:agg", "m2", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a2:agg", "dim1", 9, 9, false)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{
"21.000000000",
"21.000000000",
"13.100000000",
})
);
}
@Test
public void testCompressedBigDecimalAggDefaultScale()
{
cannotVectorize();
testQuery(
"SELECT big_sum(m1, 9), big_sum(m2, 9), big_sum(dim1, 9) FROM foo",
Collections.singletonList(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
new CompressedBigDecimalAggregatorFactory("a0:agg", "m1", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a1:agg", "m2", 9, 9, false),
new CompressedBigDecimalAggregatorFactory("a2:agg", "dim1", 9, 9, false)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{
"21.000000000",
"21.000000000",
"13.100000000"
})
);
}
@Test
public void testCompressedBigDecimalAggDefaultSizeAndScale()
{
cannotVectorize();
testQuery(
"SELECT big_sum(m1), big_sum(m2), big_sum(dim1) FROM foo",
Collections.singletonList(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
.granularity(Granularities.ALL)
.aggregators(
new CompressedBigDecimalAggregatorFactory("a0:agg", "m1", 6, 9, false),
new CompressedBigDecimalAggregatorFactory("a1:agg", "m2", 6, 9, false),
new CompressedBigDecimalAggregatorFactory("a2:agg", "dim1", 6, 9, false)
)
.context(QUERY_CONTEXT_DEFAULT)
.build()
),
ImmutableList.of(new Object[]{
"21.000000000",
"21.000000000",
"13.100000000"
})
);
}
}
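Why dim1 sums to 13.100000000: the standard test rows carry dim1 values "", "10.1", "2", "1", "def", "abc" (as defined, if memory serves, in CalciteTests.RAW_ROWS1), and lenient parsing zeroes the non-numeric ones. A runnable cross-check:

import java.math.BigDecimal;
import java.math.RoundingMode;

class LenientSumDemo
{
  public static void main(String[] args)
  {
    String[] dim1 = {"", "10.1", "2", "1", "def", "abc"}; // assumed test data
    BigDecimal sum = BigDecimal.ZERO;
    for (String s : dim1) {
      try {
        sum = sum.add(new BigDecimal(s));
      }
      catch (NumberFormatException e) {
        // lenient mode: unparseable strings contribute 0
      }
    }
    System.out.println(sum.setScale(9, RoundingMode.UP)); // 13.100000000
  }
}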

CalciteTests.java

@ -254,7 +254,7 @@ public class CalciteTests
null, null
);
private static final String TIMESTAMP_COLUMN = "t";
public static final String TIMESTAMP_COLUMN = "t";
public static final Injector INJECTOR = new CalciteTestInjectorBuilder().build();