fix complex_decode_base64 function, add SQL bindings (#13332)

* fix complex_decode_base64 function, add SQL bindings

* more permissive
This commit is contained in:
Clint Wylie 2022-11-09 23:40:25 -08:00 committed by GitHub
parent 965e41538e
commit 27215d1ff1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 395 additions and 98 deletions

View File

@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.TypeStrategy;
import javax.annotation.Nullable;
import java.util.List;
import java.util.stream.Collectors;
public class BuiltInExprMacros
{
public static class ComplexDecodeBase64ExprMacro implements ExprMacroTable.ExprMacro
{
public static final String NAME = "complex_decode_base64";
@Override
public String name()
{
return NAME;
}
@Override
public Expr apply(List<Expr> args)
{
return new ComplexDecodeBase64Expression(args);
}
final class ComplexDecodeBase64Expression extends ExprMacroTable.BaseScalarMacroFunctionExpr
{
private final ExpressionType complexType;
private final TypeStrategy<?> typeStrategy;
public ComplexDecodeBase64Expression(List<Expr> args)
{
super(NAME, args);
validationHelperCheckArgumentCount(args, 2);
final Expr arg0 = args.get(0);
if (!arg0.isLiteral()) {
throw validationFailed(
"first argument must be constant STRING expression containing a valid complex type name but got '%s' instead",
arg0.stringify()
);
}
if (arg0.isNullLiteral()) {
throw validationFailed("first argument must be constant STRING expression containing a valid complex type name but got NULL instead");
}
final Object literal = arg0.getLiteralValue();
if (!(literal instanceof String)) {
throw validationFailed(
"first argument must be constant STRING expression containing a valid complex type name but got '%s' instead",
arg0.getLiteralValue()
);
}
this.complexType = ExpressionTypeFactory.getInstance().ofComplex((String) literal);
try {
this.typeStrategy = complexType.getStrategy();
}
catch (IllegalArgumentException illegal) {
throw validationFailed(
"first argument must be a valid COMPLEX type name, got unknown COMPLEX type [%s]",
complexType.asTypeString()
);
}
}
@Override
public ExprEval<?> eval(ObjectBinding bindings)
{
ExprEval<?> toDecode = args.get(1).eval(bindings);
if (toDecode.value() == null) {
return ExprEval.ofComplex(complexType, null);
}
final Object serializedValue = toDecode.value();
final byte[] base64;
if (serializedValue instanceof String) {
base64 = StringUtils.decodeBase64String(toDecode.asString());
} else if (serializedValue instanceof byte[]) {
base64 = (byte[]) serializedValue;
} else if (complexType.getComplexTypeName().equals(toDecode.type().getComplexTypeName())) {
// pass it through, it is already the right thing
return toDecode;
} else {
throw validationFailed(
"second argument must be a base64 encoded STRING value but got %s instead",
toDecode.type()
);
}
return ExprEval.ofComplex(complexType, typeStrategy.fromBytes(base64));
}
@Override
public Expr visit(Shuttle shuttle)
{
List<Expr> newArgs = args.stream().map(x -> x.visit(shuttle)).collect(Collectors.toList());
return shuttle.visit(new ComplexDecodeBase64Expression(newArgs));
}
@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
{
return complexType;
}
@Override
public boolean isLiteral()
{
return args.get(1).isLiteral();
}
@Override
public boolean isNullLiteral()
{
return args.get(1).isNullLiteral();
}
@Nullable
@Override
public Object getLiteralValue()
{
return eval(InputBindings.nilBindings()).value();
}
}
}
}

View File

@ -23,6 +23,7 @@ import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.druid.java.util.common.StringUtils;
@ -42,18 +43,18 @@ import java.util.stream.Collectors;
*/
public class ExprMacroTable
{
private static final List<ExprMacro> BUILT_IN = ImmutableList.of(
new BuiltInExprMacros.ComplexDecodeBase64ExprMacro()
);
private static final ExprMacroTable NIL = new ExprMacroTable(Collections.emptyList());
private final Map<String, ExprMacro> macroMap;
public ExprMacroTable(final List<ExprMacro> macros)
{
this.macroMap = macros.stream().collect(
Collectors.toMap(
m -> StringUtils.toLowerCase(m.name()),
m -> m
)
);
this.macroMap = Maps.newHashMapWithExpectedSize(BUILT_IN.size() + macros.size());
macroMap.putAll(BUILT_IN.stream().collect(Collectors.toMap(m -> StringUtils.toLowerCase(m.name()), m -> m)));
macroMap.putAll(macros.stream().collect(Collectors.toMap(m -> StringUtils.toLowerCase(m.name()), m -> m)));
}
public static ExprMacroTable nil()

View File

@ -31,7 +31,6 @@ import org.apache.druid.math.expr.vector.VectorMathProcessors;
import org.apache.druid.math.expr.vector.VectorProcessors;
import org.apache.druid.math.expr.vector.VectorStringProcessors;
import org.apache.druid.segment.column.TypeSignature;
import org.apache.druid.segment.column.TypeStrategy;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.joda.time.format.DateTimeFormat;
@ -39,7 +38,6 @@ import org.joda.time.format.DateTimeFormat;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@ -3683,76 +3681,4 @@ public interface Function extends NamedFunction
return HumanReadableBytes.UnitSystem.DECIMAL;
}
}
class ComplexDecodeBase64Function implements Function
{
@Override
public String name()
{
return "complex_decode_base64";
}
@Override
public ExprEval apply(List<Expr> args, Expr.ObjectBinding bindings)
{
ExprEval arg0 = args.get(0).eval(bindings);
if (!arg0.type().is(ExprType.STRING)) {
throw validationFailed(
"first argument must be constant STRING expression containing a valid complex type name but got %s instead",
arg0.type()
);
}
ExpressionType type = ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
TypeStrategy strategy;
try {
strategy = type.getStrategy();
}
catch (IllegalArgumentException illegal) {
throw validationFailed(
"first argument must be a valid COMPLEX type name, got unknown COMPLEX type [%s]",
type.asTypeString()
);
}
ExprEval base64String = args.get(1).eval(bindings);
if (!base64String.type().is(ExprType.STRING)) {
throw validationFailed(
"second argument must be a base64 encoded STRING value but got %s instead",
base64String.type()
);
}
if (base64String.value() == null) {
return ExprEval.ofComplex(type, null);
}
final byte[] base64 = StringUtils.decodeBase64String(base64String.asString());
return ExprEval.ofComplex(type, strategy.read(ByteBuffer.wrap(base64)));
}
@Override
public void validateArguments(List<Expr> args)
{
validationHelperCheckArgumentCount(args, 2);
if (!args.get(0).isLiteral() || args.get(0).isNullLiteral()) {
throw validationFailed(
"first argument must be constant STRING expression containing a valid COMPLEX type name"
);
}
}
@Nullable
@Override
public ExpressionType getOutputType(
Expr.InputBindingInspector inspector,
List<Expr> args
)
{
ExpressionType arg0Type = args.get(0).getOutputType(inspector);
if (arg0Type == null || !arg0Type.is(ExprType.STRING)) {
throw validationFailed(
"first argument must be constant STRING expression containing a valid COMPLEX type name"
);
}
return ExpressionTypeFactory.getInstance().ofComplex((String) args.get(0).getLiteralValue());
}
}
}

View File

@ -68,7 +68,6 @@ public interface TypeStrategy<T> extends Comparator<T>
*/
int estimateSizeBytes(T value);
/**
* Read a non-null value from the {@link ByteBuffer} at the current {@link ByteBuffer#position()}. This will move
* the underlying position by the size of the value read.
@ -150,4 +149,18 @@ public interface TypeStrategy<T> extends Comparator<T>
buffer.position(oldPosition);
}
}
/**
* Translate raw byte array into a value. This is primarily useful for transforming self contained values that are
* serialized into byte arrays, such as happens with 'COMPLEX' types which serialize to base64 strings in JSON
* responses.
*
* 'COMPLEX' types should implement this method to participate in the expression systems built-in function
* to deserialize base64 encoded values,
* {@link org.apache.druid.math.expr.BuiltInExprMacros.ComplexDecodeBase64ExprMacro}.
*/
default T fromBytes(byte[] value)
{
throw new IllegalStateException("Not supported");
}
}

View File

@ -924,12 +924,25 @@ public class FunctionTest extends InitializedNullHandlingTest
);
}
@Test
public void testComplexDecodeBaseArg0Null()
{
expectedException.expect(ExpressionValidationException.class);
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be constant STRING expression containing a valid complex type name but got NULL instead"
);
assertExpr(
"complex_decode_base64(null, string)",
null
);
}
@Test
public void testComplexDecodeBaseArg0BadType()
{
expectedException.expect(ExpressionValidationException.class);
expectedException.expectMessage(
"Function[complex_decode_base64] first argument must be constant STRING expression containing a valid complex type name but got LONG instead"
"Function[complex_decode_base64] first argument must be constant STRING expression containing a valid complex type name but got '1' instead"
);
assertExpr(
"complex_decode_base64(1, string)",

View File

@ -681,5 +681,11 @@ public class TypeStrategiesTest
}
return written;
}
@Override
public NullableLongPair fromBytes(byte[] value)
{
return read(ByteBuffer.wrap(value));
}
}
}

View File

@ -160,6 +160,16 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
dimensionSpec = new DefaultDimensionSpec(virtualColumnName, null, inputType);
}
if (inputType.is(ValueType.COMPLEX)) {
aggregatorFactory = new HllSketchMergeAggregatorFactory(
aggregatorName,
dimensionSpec.getOutputName(),
logK,
tgtHllType,
finalizeSketch || SketchQueryContext.isFinalizeOuterSketches(plannerContext),
ROUND
);
} else {
aggregatorFactory = new HllSketchBuildAggregatorFactory(
aggregatorName,
dimensionSpec.getDimension(),
@ -169,6 +179,7 @@ public abstract class HllSketchBaseSqlAggregator implements SqlAggregator
ROUND
);
}
}
return toAggregation(
name,

View File

@ -54,6 +54,7 @@ public class ObjectStrategyComplexTypeStrategy<T> implements TypeStrategy<T>
{
final int complexLength = buffer.getInt();
ByteBuffer dupe = buffer.duplicate();
dupe.order(buffer.order());
dupe.limit(dupe.position() + complexLength);
return objectStrategy.fromByteBuffer(dupe, complexLength);
}
@ -85,4 +86,10 @@ public class ObjectStrategyComplexTypeStrategy<T> implements TypeStrategy<T>
{
return objectStrategy.compare(o1, o2);
}
@Override
public T fromBytes(byte[] value)
{
return objectStrategy.fromByteBuffer(ByteBuffer.wrap(value), value.length);
}
}

View File

@ -53,6 +53,7 @@ public class ComplexMetrics
{
COMPLEX_SERIALIZERS.compute(type, (key, value) -> {
if (value == null) {
TypeStrategies.registerComplex(type, serde.getTypeStrategy());
return serde;
} else {
if (!value.getClass().getName().equals(serde.getClass().getName())) {
@ -63,7 +64,6 @@ public class ComplexMetrics
value.getClass().getName()
);
} else {
TypeStrategies.registerComplex(type, serde.getTypeStrategy());
return value;
}
}

View File

@ -121,6 +121,14 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
dimensionSpec = new DefaultDimensionSpec(virtualColumnName, null, inputType);
}
if (inputType.is(ValueType.COMPLEX)) {
aggregatorFactory = new HyperUniquesAggregatorFactory(
aggregatorName,
dimensionSpec.getOutputName(),
false,
true
);
} else {
aggregatorFactory = new CardinalityAggregatorFactory(
aggregatorName,
null,
@ -129,6 +137,7 @@ public class BuiltinApproxCountDistinctSqlAggregator implements SqlAggregator
true
);
}
}
return Aggregation.create(
Collections.singletonList(aggregatorFactory),

View File

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.BuiltInExprMacros;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
import org.apache.druid.sql.calcite.expression.SqlOperatorConversion;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.table.RowSignatures;
import javax.annotation.Nullable;
public class ComplexDecodeBase64OperatorConversion implements SqlOperatorConversion
{
public static final SqlReturnTypeInference ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE = opBinding -> {
String typeName = opBinding.getOperandLiteralValue(0, String.class);
return RowSignatures.makeComplexType(
opBinding.getTypeFactory(),
ColumnType.ofComplex(typeName),
true
);
};
private static final SqlFunction SQL_FUNCTION = OperatorConversions
.operatorBuilder(StringUtils.toUpperCase(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME))
.operandTypeChecker(
OperandTypes.sequence(
"(typeName,base64)",
OperandTypes.and(OperandTypes.family(SqlTypeFamily.STRING), OperandTypes.LITERAL),
OperandTypes.ANY
)
)
.returnTypeInference(ARBITRARY_COMPLEX_RETURN_TYPE_INFERENCE)
.functionCategory(SqlFunctionCategory.USER_DEFINED_FUNCTION)
.build();
@Override
public SqlOperator calciteOperator()
{
return SQL_FUNCTION;
}
@Nullable
@Override
public DruidExpression toDruidExpression(
PlannerContext plannerContext,
RowSignature rowSignature,
RexNode rexNode
)
{
return OperatorConversions.convertCall(
plannerContext,
rowSignature,
rexNode,
druidExpressions -> {
String arg0 = druidExpressions.get(0).getExpression();
return DruidExpression.ofExpression(
ColumnType.ofComplex(arg0.substring(1, arg0.length() - 1)),
DruidExpression.functionCall(BuiltInExprMacros.ComplexDecodeBase64ExprMacro.NAME),
druidExpressions
);
}
);
}
}

View File

@ -71,6 +71,7 @@ import org.apache.druid.sql.calcite.expression.builtin.ArrayToStringOperatorConv
import org.apache.druid.sql.calcite.expression.builtin.BTrimOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.CastOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.CeilOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ComplexDecodeBase64OperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ConcatOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ContainsOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.DateTruncOperatorConversion;
@ -211,6 +212,7 @@ public class DruidOperatorTable implements SqlOperatorTable
ImmutableList.<SqlOperatorConversion>builder()
.add(new CastOperatorConversion())
.add(new ReinterpretOperatorConversion())
.add(new ComplexDecodeBase64OperatorConversion())
.build();
private static final List<SqlOperatorConversion> ARRAY_OPERATOR_CONVERSIONS =

View File

@ -14293,4 +14293,71 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
ImmutableList.of()
);
}
@Test
public void testComplexDecode()
{
cannotVectorize();
testQuery(
"SELECT COMPLEX_DECODE_BASE64('hyperUnique',PARSE_JSON(TO_JSON_STRING(unique_dim1))) from druid.foo LIMIT 10",
ImmutableList.of(
Druids.newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("v0")
.virtualColumns(
expressionVirtualColumn(
"v0",
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
ColumnType.ofComplex("hyperUnique")
)
)
.resultFormat(ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.legacy(false)
.limit(10)
.build()
),
ImmutableList.of(
new Object[]{"\"AQAAAEAAAA==\""},
new Object[]{"\"AQAAAQAAAAHNBA==\""},
new Object[]{"\"AQAAAQAAAAOzAg==\""},
new Object[]{"\"AQAAAQAAAAFREA==\""},
new Object[]{"\"AQAAAQAAAACyEA==\""},
new Object[]{"\"AQAAAQAAAAEkAQ==\""}
)
);
}
@Test
public void testComplexDecodeAgg()
{
cannotVectorize();
testQuery(
"SELECT APPROX_COUNT_DISTINCT_BUILTIN(COMPLEX_DECODE_BASE64('hyperUnique',PARSE_JSON(TO_JSON_STRING(unique_dim1)))) from druid.foo",
ImmutableList.of(
Druids.newTimeseriesQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.virtualColumns(
expressionVirtualColumn(
"v0",
"complex_decode_base64('hyperUnique',parse_json(to_json_string(\"unique_dim1\")))",
ColumnType.ofComplex("hyperUnique")
)
)
.aggregators(
new HyperUniquesAggregatorFactory(
"a0",
"v0",
false,
true
)
)
.build()
),
ImmutableList.of(
new Object[]{6L}
)
);
}
}