new SCALAR_IN_ARRAY function analogous to DRUID_IN (#16306)

* scalar_in function

* api doc

* refactor
This commit is contained in:
Sree Charan Manamala 2024-04-19 09:45:15 +05:30 committed by GitHub
parent 79e48c6b45
commit ad5701e891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 180 additions and 0 deletions

View File

@ -184,6 +184,7 @@ See javadoc of java.lang.Math for detailed explanation for each function.
| array_ordinal(arr,long) | returns the array element at the 1 based index supplied, or null for an out of range index |
| array_contains(arr,expr) | returns 1 if the array contains the element specified by expr, or contains all elements specified by expr if expr is an array, else 0 |
| array_overlap(arr1,arr2) | returns 1 if arr1 and arr2 have any elements in common, else 0 |
| scalar_in_array(expr, arr) | returns 1 if the scalar is present in the array, else 0 |
| array_offset_of(arr,expr) | returns the 0 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
| array_ordinal_of(arr,expr) | returns the 1 based index of the first occurrence of expr in the array, or `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode) if no matching elements exist in the array. |
| array_prepend(expr,arr) | adds expr to arr at the beginning, the resulting array type determined by the type of the array |

View File

@ -54,6 +54,7 @@ The following table describes array functions. To learn more about array aggrega
|`ARRAY_ORDINAL(arr, long)`|Returns the array element at the 1-based index supplied, or null for an out of range index.|
|`ARRAY_CONTAINS(arr, expr)`|If `expr` is a scalar type, returns 1 if `arr` contains `expr`. If `expr` is an array, returns 1 if `arr` contains all elements of `expr`. Otherwise returns 0.|
|`ARRAY_OVERLAP(arr1, arr2)`|Returns 1 if `arr1` and `arr2` have any elements in common, else 0.|
|`SCALAR_IN_ARRAY(expr, arr)`|Returns 1 if the scalar `expr` is present in `arr`, else 0.|
|`ARRAY_OFFSET_OF(arr, expr)`|Returns the 0-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
|`ARRAY_ORDINAL_OF(arr, expr)`|Returns the 1-based index of the first occurrence of `expr` in the array. If no matching elements exist in the array, returns `null` or `-1` if `druid.generic.useDefaultValueForNull=true` (deprecated legacy mode).|
|`ARRAY_PREPEND(expr, arr)`|Adds `expr` to the beginning of `arr`, the resulting array type determined by the type of `arr`.|

View File

@ -206,6 +206,14 @@ Returns the 1-based index of the first occurrence of `expr` in the array. If no
Returns 1 if `arr1` and `arr2` have any elements in common, else 0.
## SCALAR_IN_ARRAY
`SCALAR_IN_ARRAY(expr, arr)`
**Function type:** [Array](./sql-array-functions.md)
Returns 1 if the scalar `expr` is present in `arr`, else 0.
## ARRAY_PREPEND
`ARRAY_PREPEND(expr, arr)`

View File

@ -3724,6 +3724,44 @@ public interface Function extends NamedFunction
}
}
/**
 * Expression function {@code scalar_in_array(expr, arr)}: tests whether a scalar value occurs in an array.
 *
 * Unlike most array/scalar functions, the scalar is the FIRST argument and the array is the SECOND.
 */
class ArrayScalarInFunction extends ArrayScalarFunction
{
  @Override
  public String name()
  {
    return "scalar_in_array";
  }

  /**
   * The result type is always {@link ExpressionType#LONG}, regardless of the argument types.
   */
  @Nullable
  @Override
  public ExpressionType getOutputType(Expr.InputBindingInspector inspector, List<Expr> args)
  {
    return ExpressionType.LONG;
  }

  /**
   * The scalar operand is argument 0.
   */
  @Override
  Expr getScalarArgument(List<Expr> args)
  {
    return args.get(0);
  }

  /**
   * The array operand is argument 1.
   */
  @Override
  Expr getArrayArgument(List<Expr> args)
  {
    return args.get(1);
  }

  /**
   * Casts the array to the array type derived from the scalar, then checks whether the scalar's value is one of the
   * resulting elements. If the cast array is null, the result is a null long.
   */
  @Override
  ExprEval doApply(ExprEval arrayExpr, ExprEval scalarExpr)
  {
    // Align the array's element type with the scalar so that List#contains compares like with like.
    final ExprEval coerced = arrayExpr.castTo(scalarExpr.asArrayType());
    final Object[] elements = coerced.asArray();
    if (elements != null) {
      final boolean found = Arrays.asList(elements).contains(scalarExpr.value());
      return ExprEval.ofLongBoolean(found);
    }
    return ExprEval.ofLong(null);
  }
}
class ArrayAppendFunction extends ArrayAddElementFunction
{
@Override

View File

@ -369,6 +369,18 @@ public class FunctionTest extends InitializedNullHandlingTest
assertExpr("array_ordinal_of(a, 'baz')", 3L);
}
@Test
public void testScalarInArray()
{
  // Literal scalar against a literal array: present, then absent.
  assertExpr("scalar_in_array(2, [1, 2, 3])", 1L);
  assertExpr("scalar_in_array(4, [1, 2, 3])", 0L);

  // `b` is not a literal; presumably it is resolved from the test bindings (defined outside this view).
  assertExpr("scalar_in_array(b, [3, 4])", 0L);

  // A null scalar is found only when the array itself holds a null element.
  assertExpr("scalar_in_array(null, [1, null, 2])", 1L);
  assertExpr("scalar_in_array(null, [1, 2])", 0L);

  // A null array yields null, whatever the scalar is.
  assertExpr("scalar_in_array(1, null)", null);
  assertExpr("scalar_in_array(null, null)", null);
}
@Test
public void testArrayContains()
{

View File

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.sql.calcite.expression.builtin;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
import org.apache.druid.sql.calcite.expression.OperatorConversions;
/**
 * SQL operator conversion for {@code SCALAR_IN_ARRAY(expr, array)}.
 *
 * The constructor hands the SQL function definition and the native expression name {@code scalar_in_array} to
 * {@link DirectOperatorConversion}.
 */
public class ScalarInArrayOperatorConversion extends DirectOperatorConversion
{
  /** Name of the native expression function passed to the superclass. */
  private static final String NATIVE_FUNCTION_NAME = "scalar_in_array";

  private static final SqlFunction SQL_FUNCTION = createSqlFunction();

  public ScalarInArrayOperatorConversion()
  {
    super(SQL_FUNCTION, NATIVE_FUNCTION_NAME);
  }

  /**
   * Builds the SQL-side function: operand 1 must be CHARACTER or NUMERIC, operand 2 must be an ARRAY, and the return
   * type is inferred with {@link ReturnTypes#BOOLEAN_NULLABLE}.
   */
  private static SqlFunction createSqlFunction()
  {
    return OperatorConversions
        .operatorBuilder("SCALAR_IN_ARRAY")
        .operandTypeChecker(
            OperandTypes.sequence(
                "'SCALAR_IN_ARRAY(expr, array)'",
                // first operand: the scalar being searched for
                OperandTypes.or(
                    OperandTypes.family(SqlTypeFamily.CHARACTER),
                    OperandTypes.family(SqlTypeFamily.NUMERIC)
                ),
                // second operand: the array being searched
                OperandTypes.family(SqlTypeFamily.ARRAY)
            )
        )
        .returnTypeInference(ReturnTypes.BOOLEAN_NULLABLE)
        .build();
  }
}

View File

@ -110,6 +110,7 @@ import org.apache.druid.sql.calcite.expression.builtin.RepeatOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ReverseOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.RightOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.SafeDivideOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.ScalarInArrayOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.SearchOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.StringFormatOperatorConversion;
import org.apache.druid.sql.calcite.expression.builtin.StringToArrayOperatorConversion;
@ -262,6 +263,7 @@ public class DruidOperatorTable implements SqlOperatorTable
.add(new ArrayToStringOperatorConversion())
.add(new StringToArrayOperatorConversion())
.add(new ArrayToMultiValueStringOperatorConversion())
.add(new ScalarInArrayOperatorConversion())
.build();
private static final List<SqlOperatorConversion> MULTIVALUE_STRING_OPERATOR_CONVERSIONS =

View File

@ -1338,6 +1338,73 @@ public class CalciteArraysQueryTest extends BaseCalciteQueryTest
);
}
@Test
public void testScalarInArrayFilter()
{
  // NOTE(review): helper defined outside this view; presumably flags the test as not applicable to MSQ - confirm.
  msqIncompatible();

  final String sql = "SELECT dim2 FROM druid.numfoo WHERE SCALAR_IN_ARRAY(dim2, ARRAY['a', 'd']) LIMIT 5";

  // The SQL predicate is expected to plan as a native expression filter on scalar_in_array.
  final ScanQuery expectedQuery =
      newScanQueryBuilder()
          .dataSource(CalciteTests.DATASOURCE3)
          .intervals(querySegmentSpec(Filtration.eternity()))
          .filters(
              new ExpressionDimFilter("scalar_in_array(\"dim2\",array('a','d'))", ExprMacroTable.nil())
          )
          .columns("dim2")
          .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
          .limit(5)
          .context(QUERY_CONTEXT_DEFAULT)
          .build();

  // Expected output: two rows, both with dim2 = 'a'.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(
      new Object[]{"a"},
      new Object[]{"a"}
  );

  testQuery(sql, ImmutableList.of(expectedQuery), expectedRows);
}
@Test
public void testArrayScalarInFilter_MVD()
{
  // NOTE(review): helper defined outside this view; presumably flags the test as not applicable to MSQ - confirm.
  msqIncompatible();

  final String sql =
      "SELECT dim3, (CASE WHEN scalar_in_array(dim3, Array['a', 'b', 'd']) THEN 'abd' ELSE 'not abd' END) "
      + "FROM druid.numfoo";

  // The CASE expression is expected to plan as a single expression virtual column over dim3.
  final ExpressionVirtualColumn caseColumn = new ExpressionVirtualColumn(
      "v0",
      "case_searched(scalar_in_array(\"dim3\",array('a','b','d')),'abd','not abd')",
      ColumnType.STRING,
      ExprMacroTable.nil()
  );

  // Expected (dim3, v0) pairs; rows whose dim3 holds several values carry one CASE result per value.
  final ImmutableList<Object[]> expectedRows = ImmutableList.of(
      new Object[]{"[\"a\",\"b\"]", "[\"abd\",\"abd\"]"},
      new Object[]{"[\"b\",\"c\"]", "[\"abd\",\"not abd\"]"},
      new Object[]{"d", "abd"},
      new Object[]{"", "not abd"},
      new Object[]{null, "not abd"},
      new Object[]{null, "not abd"}
  );

  testBuilder()
      .sql(sql)
      .expectedQueries(
          ImmutableList.of(
              newScanQueryBuilder()
                  .dataSource(CalciteTests.DATASOURCE3)
                  .intervals(querySegmentSpec(Filtration.eternity()))
                  .virtualColumns(caseColumn)
                  .columns("dim3", "v0")
                  .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
                  .context(QUERY_CONTEXT_DEFAULT)
                  .build()
          )
      )
      .expectedResults(ResultMatchMode.RELAX_NULLS, expectedRows)
      .run();
}
@Test
public void testArraySlice()

View File

@ -1508,6 +1508,7 @@ array_overlap
array_prepend
array_slice
array_to_string
scalar_in_array
asin
atan
atan2