Synch more functionality

Original commit: elastic/x-pack-elasticsearch@093c275b85
This commit is contained in:
Costin Leau 2017-06-30 20:41:47 +03:00
parent 0e8ef06947
commit be2153851c
76 changed files with 2191 additions and 935 deletions

View File

@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.io.Reader;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Properties;
import java.util.function.Supplier;
import org.elasticsearch.xpack.sql.jdbc.integration.util.JdbcTemplate.JdbcSupplier;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.relique.jdbc.csv.CsvDriver;
@RunWith(Suite.class)
public abstract class CsvInfraSuite extends EsInfra {
private static CsvDriver DRIVER = new CsvDriver();
public static final Map<Connection, Reader> CSV_READERS = new LinkedHashMap<>();
@BeforeClass
public static void setupDB() throws Exception {
EsInfra.setupDB();
}
@AfterClass
public static void cleanup() throws Exception {
CSV_READERS.clear();
}
public static Supplier<Connection> csvCon(Properties props, Reader reader) {
return new JdbcSupplier<Connection>() {
@Override
public Connection jdbc() throws SQLException {
Connection con = DRIVER.connect("jdbc:relique:csv:class:" + CsvSpecTableReader.class.getName(), props);
CSV_READERS.put(con, reader);
return con;
}
};
}
}

View File

@ -0,0 +1,115 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.io.Reader;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.ResultSet;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import org.junit.Test;
import org.junit.runners.Parameterized.Parameter;
import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.jdbc.integration.util.JdbcAssert.assertResultSets;
public abstract class CsvSpecBaseTest extends SpecBaseTest {
@Parameter(3)
public CsvFragment fragment;
protected static List<Object[]> readScriptSpec(String url) throws Exception {
return SpecBaseTest.readScriptSpec(url, new CsvSpecParser());
}
@Test
public void testQuery() throws Throwable {
try (Connection csv = CsvInfraSuite.csvCon(fragment.asProps(), fragment.reader).get();
Connection es = CsvInfraSuite.esCon().get()) {
ResultSet expected, actual;
try {
// pass the testName as table for debugging purposes (in case the underlying reader is missing)
expected = csv.createStatement(ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY).executeQuery("SELECT * FROM " + testName);
// trigger data loading for type inference
expected.beforeFirst();
actual = es.createStatement().executeQuery(fragment.query);
assertResultSets(expected, actual);
} catch (AssertionError ae) {
throw reworkException(new AssertionError(errorMessage(ae), ae.getCause()));
}
} catch (Throwable th) {
throw new RuntimeException(errorMessage(th), th);
}
}
String errorMessage(Throwable th) {
return format(Locale.ROOT, "test%s@%s:%d failed\n\"%s\"\n%s", testName, source.getFileName().toString(), lineNumber, fragment.query, th.getMessage());
}
}
class CsvSpecParser implements SpecBaseTest.Parser {
private final StringBuilder data = new StringBuilder();
private CsvFragment fragment;
@Override
public Object parse(String line) {
// beginning of the section
if (fragment == null) {
// pick up the query
fragment = new CsvFragment();
fragment.query = line.endsWith(";") ? line.substring(0, line.length() - 1) : line;
}
else {
// read CSV header
// if (fragment.columnNames == null) {
// fragment.columnNames = line;
// }
// read data
if (line.startsWith(";")) {
CsvFragment f = fragment;
f.reader = new StringReader(data.toString());
// clean-up
fragment = null;
data.setLength(0);
return f;
}
else {
data.append(line);
data.append("\r\n");
}
}
return null;
}
}
class CsvFragment {
String query;
String columnNames;
List<String> columnTypes;
Reader reader;
private static final Properties DEFAULT = new Properties();
static {
DEFAULT.setProperty("charset", "UTF-8");
// trigger auto-detection
DEFAULT.setProperty("columnTypes", "");
DEFAULT.setProperty("separator", "|");
DEFAULT.setProperty("trimValues", "true");
}
Properties asProps() {
// p.setProperty("suppressHeaders", "true");
// p.setProperty("headerline", columnNames);
return DEFAULT;
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.io.Reader;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.List;
import org.relique.io.TableReader;
public class CsvSpecTableReader implements TableReader {
@Override
public Reader getReader(Statement statement, String tableName) throws SQLException {
Reader reader = CsvInfraSuite.CSV_READERS.remove(statement.getConnection());
if (reader == null) {
throw new RuntimeException("Cannot find reader for test " + tableName);
}
return reader;
}
@Override
public List<String> getTableNames(Connection connection) throws SQLException {
throw new UnsupportedOperationException();
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.sql.Connection;
import java.util.function.Supplier;
import org.elasticsearch.xpack.sql.jdbc.integration.util.EsDataLoader;
import org.elasticsearch.xpack.sql.jdbc.integration.util.EsJdbcServer;
import org.elasticsearch.xpack.sql.jdbc.integration.util.JdbcTemplate;
import org.junit.ClassRule;
import static org.junit.Assert.assertNotNull;
public class EsInfra {
//
// REMOTE ACCESS
//
private static boolean REMOTE = true;
@ClassRule
public static EsJdbcServer ES_JDBC_SERVER = new EsJdbcServer(REMOTE, false);
private static JdbcTemplate ES_JDBC;
public static void setupDB() throws Exception {
//ES_CON = new JdbcTemplate(ES_JDBC_SERVER);
if (!REMOTE) {
setupES();
}
}
private static void setupES() throws Exception {
EsDataLoader.loadData();
}
public static Supplier<Connection> esCon() {
return ES_JDBC_SERVER;
}
public static JdbcTemplate es() {
assertNotNull("ES connection null - make sure the suite is ran", ES_JDBC);
return ES_JDBC;
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.elasticsearch.common.Strings;
import org.junit.Assert;
import org.junit.runners.Parameterized.Parameter;
import static java.lang.String.format;
public abstract class SpecBaseTest {
@Parameter(0)
public String testName;
@Parameter(1)
public Integer lineNumber;
@Parameter(2)
public Path source;
interface Parser {
Object parse(String line);
}
// returns testName, its line location, its source and the custom object (based on each test parser)
protected static List<Object[]> readScriptSpec(String url, Parser parser) throws Exception {
Path source = Paths.get(SpecBaseTest.class.getResource(url).toURI());
List<String> lines = Files.readAllLines(source);
Map<String, Integer> testNames = new LinkedHashMap<>();
List<Object[]> pairs = new ArrayList<>();
String name = null;
for (int i = 0; i < lines.size(); i++) {
String line = lines.get(i).trim();
// ignore comments
if (!line.isEmpty() && !line.startsWith("//")) {
// parse test name
if (name == null) {
if (testNames.keySet().contains(line)) {
throw new IllegalStateException(format(Locale.ROOT, "Duplicate test name '%s' at line %d (previously seen at line %d)", line, i, testNames.get(line)));
}
else {
name = Strings.capitalize(line);
testNames.put(name, Integer.valueOf(i));
}
}
else {
Object result = parser.parse(line);
// only if the parser is ready, add the object - otherwise keep on serving it lines
if (result != null) {
pairs.add(new Object[] { name, Integer.valueOf(i), source, result });
name = null;
}
}
}
}
Assert.assertNull("Cannot find spec for test " + name, name);
return pairs;
}
Throwable reworkException(Throwable th) {
StackTraceElement[] stackTrace = th.getStackTrace();
StackTraceElement[] redone = new StackTraceElement[stackTrace.length + 1];
System.arraycopy(stackTrace, 0, redone, 1, stackTrace.length);
redone[0] = new StackTraceElement(getClass().getName(), testName, source.getFileName().toString(), lineNumber);
th.setStackTrace(redone);
return th;
}
}

View File

@ -0,0 +1,49 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.sql.Connection;
import java.util.function.Supplier;
import org.elasticsearch.xpack.sql.jdbc.integration.util.H2;
import org.elasticsearch.xpack.sql.jdbc.integration.util.JdbcTemplate;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import static org.junit.Assert.assertNotNull;
@RunWith(Suite.class)
public abstract class SqlInfraSuite extends EsInfra {
private static String REMOTE_H2 = "jdbc:h2:tcp://localhost/./essql";
@ClassRule
public static H2 H2 = new H2(null);
private static JdbcTemplate H2_JDBC;
@BeforeClass
public static void setupDB() throws Exception {
H2_JDBC = new JdbcTemplate(H2);
setupH2();
EsInfra.setupDB();
}
private static void setupH2() throws Exception {
h2().execute("RUNSCRIPT FROM 'classpath:org/elasticsearch/sql/jdbc/integration/h2-setup.sql'");
}
public static Supplier<Connection> h2Con() {
return H2;
}
public static JdbcTemplate h2() {
assertNotNull("H2 connection null - make sure the suite is ran", H2_JDBC);
return H2_JDBC;
}
}

View File

@ -0,0 +1,58 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.jdbc.integration.util.framework;
import java.sql.Connection;
import java.sql.ResultSet;
import java.util.List;
import java.util.Locale;
import org.junit.Test;
import org.junit.runners.Parameterized.Parameter;
import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.jdbc.integration.util.JdbcAssert.assertResultSets;
public abstract class SqlSpecBaseTest extends SpecBaseTest {
@Parameter(3)
public String query;
protected static List<Object[]> readScriptSpec(String url) throws Exception {
return SpecBaseTest.readScriptSpec(url, new SqlSpecParser());
}
@Test
public void testQuery() throws Throwable {
// H2 resultset
try (Connection h2 = SqlInfraSuite.h2Con().get();
Connection es = SqlInfraSuite.esCon().get()) {
ResultSet expected, actual;
try {
expected = h2.createStatement().executeQuery(query);
actual = es.createStatement().executeQuery(query);
assertResultSets(expected, actual);
} catch (AssertionError ae) {
throw reworkException(new AssertionError(errorMessage(ae), ae.getCause()));
}
} catch (Throwable th) {
throw reworkException(new RuntimeException(errorMessage(th)));
}
}
String errorMessage(Throwable th) {
return format(Locale.ROOT, "test%s@%s:%d failed\n\"%s\"\n%s", testName, source.getFileName().toString(), lineNumber, query, th.getMessage());
}
}
class SqlSpecParser implements SpecBaseTest.Parser {
@Override
public Object parse(String line) {
return line.endsWith(";") ? line.substring(0, line.length() - 1) : line;
}
}

View File

@ -0,0 +1,22 @@
// some comment
// name of the test - translated into 'testName'
name
// ES SQL query
SELECT COUNT(*) FROM "emp.emp";
//
// expected result in CSV format
//
// list of <ColumnName:ColumnType*>
// type might be missing in which case it will be autodetected or can be one of the following
// d - double, f - float, i - int, b - byte, l - long, t - timestamp, date
A,B:d,C:i
// actual values
foo,2.5,3
bar,3.5,4
tar,4.5,5
;
// repeat the above

View File

@ -0,0 +1,8 @@
// some comment
// name of the test - translated into 'testName'
name
// SQL query to be executed against H2 and ES
SELECT COUNT(*) FROM "emp.emp";
// repeat the above

View File

@ -5,8 +5,11 @@ description = 'The server components of SQL for Elasticsearch'
dependencies {
compile project(':x-pack-elasticsearch:sql:jdbc-proto')
compile project(':x-pack-elasticsearch:sql:cli-proto')
provided "org.elasticsearch.plugin:aggs-matrix-stats-client:${project.versions.elasticsearch}"
//NOCOMMIT - we should upgrade to the latest 4.5.x if not 4.7
compile 'org.antlr:antlr4-runtime:4.5.1-1'
provided "org.elasticsearch:elasticsearch:${project.versions.elasticsearch}"
}
dependencyLicenses {

View File

@ -5,15 +5,6 @@
*/
package org.elasticsearch.xpack.sql.analysis.analyzer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import org.elasticsearch.xpack.sql.analysis.AnalysisException;
import org.elasticsearch.xpack.sql.analysis.UnknownFunctionException;
import org.elasticsearch.xpack.sql.analysis.UnknownIndexException;
@ -60,11 +51,20 @@ import org.elasticsearch.xpack.sql.tree.Node;
import org.elasticsearch.xpack.sql.type.CompoundDataType;
import org.elasticsearch.xpack.sql.util.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Stream;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.elasticsearch.xpack.sql.util.CollectionUtils.combine;
public class Analyzer extends RuleExecutor<LogicalPlan> {
@ -97,7 +97,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
return Arrays.asList(substitution, resolution);
}
public LogicalPlan analyze(LogicalPlan plan) {
return analyze(plan, true);
}
@ -467,14 +467,12 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
if (ordinal > 0 && ordinal <= max) {
NamedExpression reference = aggregates.get(ordinal);
if (containsAggregate(reference)) {
throw new AnalysisException(exp,
"Group ordinal %d refers to an aggregate function %s which is not compatible/allowed with GROUP BY", ordinal, reference.nodeName());
throw new AnalysisException(exp, "Group ordinal %d refers to an aggregate function %s which is not compatible/allowed with GROUP BY", ordinal, reference.nodeName());
}
newGroupings.add(reference);
}
else {
throw new AnalysisException(exp,
"Invalid ordinal %d specified in Aggregate (valid range is [1, %d])", ordinal, max);
throw new AnalysisException(exp, "Invalid ordinal %d specified in Aggregate (valid range is [1, %d])", ordinal, max);
}
}
else {
@ -581,21 +579,18 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
if (plan instanceof Aggregate) {
Aggregate a = (Aggregate) plan;
// missing attributes can only be grouping expressions
for (Expression g : a.groupings()) {
for (Attribute m : missing) {
if (!g.anyMatch(e -> e.canonicalEquals(m))) {
// no match - bail out
return a;
}
for (Attribute m : missing) {
// but we don't can't add an agg if the group is missing
if (!Expressions.anyMatchInList(a.groupings(), g -> g.canonicalEquals(m))) {
// we cannot propagate the missing attribute, bail out
//throw new AnalysisException(logicalPlan, "Cannot add missing attribute %s to %s", m.name(), plan);
return plan;
}
}
return new Aggregate(a.location(), a.child(), a.groupings(), combine(a.aggregates(), missing));
}
return plan;
// we cannot propagate the missing attribute, bail out
//throw new AnalysisException(format("Cannot add missing attribute %s to node %s", missing, plan), plan);
});
}
}
@ -644,7 +639,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
}
// TODO: might be removed
// dedicated count optimization
if (uf.name().equals("COUNT")) {
if (name.toUpperCase(Locale.ROOT).equals("COUNT")) {
uf = new UnresolvedFunction(uf.location(), uf.name(), uf.distinct(), singletonList(Literal.of(uf.arguments().get(0).location(), Integer.valueOf(1))));
}
}
@ -969,7 +964,7 @@ public class Analyzer extends RuleExecutor<LogicalPlan> {
}
}
abstract static class AnalyzeRule<SubPlan extends LogicalPlan> extends Rule<SubPlan, LogicalPlan> {
static abstract class AnalyzeRule<SubPlan extends LogicalPlan> extends Rule<SubPlan, LogicalPlan> {
// transformUp (post-order) - that is first children and then the node
// but with a twist; only if the tree is not resolved or analyzed

View File

@ -93,7 +93,7 @@ abstract class Verifier {
//
// first look at expressions
p.forEachExpressionsUp(e -> e.forEachUp(ae -> {
p.forEachExpressions(e -> e.forEachUp(ae -> {
if (ae.typeResolved().unresolved()) {
localFailures.add(fail(ae, ae.typeResolved().message()));
}

View File

@ -5,8 +5,6 @@
*/
package org.elasticsearch.xpack.sql.analysis.catalog;
import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetaData;
@ -14,7 +12,6 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.MappingMetaData;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.regex.Regex;
import java.util.ArrayList;
import java.util.Collection;
@ -105,28 +102,30 @@ public class EsCatalog implements Catalog {
}
@Override
public Collection<EsType> listTypes(String indexPattern, String typePattern) {
public Collection<EsType> listTypes(String indexPattern, String pattern) {
if (!Strings.hasText(indexPattern)) {
indexPattern = WILDCARD;
}
String[] indices = indexNameExpressionResolver.concreteIndexNames(clusterState.get(),
IndicesOptions.strictExpandOpenAndForbidClosed(), indexPattern);
String[] iName = indexNameExpressionResolver.concreteIndexNames(clusterState.get(), IndicesOptions.strictExpandOpenAndForbidClosed(), indexPattern);
List<EsType> types = new ArrayList<>();
List<EsType> results = new ArrayList<>();
for (String index : indices) {
IndexMetaData imd = metadata().index(index);
for (ObjectObjectCursor<String, MappingMetaData> entry : imd.getMappings()) {
if (false == Strings.hasLength(typePattern) || Regex.simpleMatch(typePattern, entry.key)) {
results.add(EsType.build(index, entry.key, entry.value));
}
for (String cIndex : iName) {
IndexMetaData imd = metadata().index(cIndex);
if (Strings.hasText(pattern)) {
types.add(EsType.build(cIndex, pattern, imd.mapping(pattern)));
}
else {
types.addAll(EsType.build(cIndex, imd.getMappings()));
}
}
return results;
return types;
}
private String[] resolveIndex(String pattern) {
return indexNameExpressionResolver.concreteIndexNames(clusterState.get(), IndicesOptions.strictExpandOpenAndForbidClosed(),
pattern);
return indexNameExpressionResolver.concreteIndexNames(clusterState.get(), IndicesOptions.strictExpandOpenAndForbidClosed(), pattern);
}
}

View File

@ -5,12 +5,17 @@
*/
package org.elasticsearch.xpack.sql.analysis.catalog;
import org.elasticsearch.ElasticsearchParseException;
import com.carrotsearch.hppc.cursors.ObjectObjectCursor;
import org.elasticsearch.cluster.metadata.MappingMetaData;
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.Types;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class EsType {
@ -40,17 +45,22 @@ public class EsType {
}
static EsType build(String index, String type, MappingMetaData metaData) {
Map<String, Object> asMap;
try {
asMap = metaData.sourceAsMap();
} catch (ElasticsearchParseException ex) {
throw new MappingException("Cannot get mapping info", ex);
}
Map<String, Object> asMap = metaData.sourceAsMap();
Map<String, DataType> mapping = Types.fromEs(asMap);
return new EsType(index, type, mapping);
}
static Collection<EsType> build(String index, ImmutableOpenMap<String, MappingMetaData> mapping) {
List<EsType> tps = new ArrayList<>();
for (ObjectObjectCursor<String, MappingMetaData> entry : mapping) {
tps.add(build(index, entry.key, entry.value));
}
return tps;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();

View File

@ -12,8 +12,9 @@ import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static java.util.Collections.emptyList;
import org.elasticsearch.xpack.sql.util.StringUtils;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.toList;
public abstract class InMemoryCatalog implements Catalog {
@ -53,7 +54,7 @@ public abstract class InMemoryCatalog implements Catalog {
@Override
public Collection<EsIndex> listIndices(String pattern) {
Pattern p = Pattern.compile(pattern);
Pattern p = StringUtils.likeRegex(pattern);
return indices.entrySet().stream()
.filter(e -> p.matcher(e.getKey()).matches())
.map(Map.Entry::getValue)
@ -89,7 +90,7 @@ public abstract class InMemoryCatalog implements Catalog {
return emptyList();
}
Pattern p = Pattern.compile(pattern);
Pattern p = StringUtils.likeRegex(pattern);
return typs.entrySet().stream()
.filter(e -> p.matcher(e.getKey()).matches())
.map(Map.Entry::getValue)

View File

@ -6,14 +6,14 @@
package org.elasticsearch.xpack.sql.execution.search;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
class ProcessingHitExtractor implements HitExtractor {
final HitExtractor delegate;
private final ColumnsProcessor processor;
private final ColumnProcessor processor;
ProcessingHitExtractor(HitExtractor delegate, ColumnsProcessor processor) {
ProcessingHitExtractor(HitExtractor delegate, ColumnProcessor processor) {
this.delegate = delegate;
this.processor = processor;
}

View File

@ -5,10 +5,6 @@
*/
package org.elasticsearch.xpack.sql.execution.search;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
@ -27,8 +23,8 @@ import org.elasticsearch.search.aggregations.support.AggregationPath;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.execution.ExecutionException;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.querydsl.agg.Agg;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.querydsl.agg.AggPath;
import org.elasticsearch.xpack.sql.querydsl.container.AggRef;
import org.elasticsearch.xpack.sql.querydsl.container.NestedFieldRef;
import org.elasticsearch.xpack.sql.querydsl.container.ProcessingRef;
@ -42,6 +38,10 @@ import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.type.Schema;
import org.elasticsearch.xpack.sql.util.ObjectUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
// TODO: add retry/back-off
public class Scroller {
@ -87,268 +87,273 @@ public class Scroller {
ScrollerActionListener l = new SessionScrollActionListener(listener, previous.client, previous.keepAlive, previous.schema, ext, previous.limit, previous.docsRead);
previous.client.searchScroll(new SearchScrollRequest(scrollId).scroll(previous.keepAlive), l);
}
}
/**
* Dedicated scroll used for aggs-only/group-by results.
*/
static class AggsScrollActionListener extends ScrollerActionListener {
private final QueryContainer query;
// dedicated scroll used for aggs-only/group-by results
class AggsScrollActionListener extends ScrollerActionListener {
AggsScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, QueryContainer query) {
super(listener, client, keepAlive, schema);
this.query = query;
}
private final QueryContainer query;
@Override
protected RowSetCursor handleResponse(SearchResponse response) {
Aggregations aggs = response.getAggregations();
List<Object[]> columns = new ArrayList<>();
// this method assumes the nested aggregation are all part of the same tree (the SQL group-by)
int maxDepth = -1;
for (Reference ref : query.refs()) {
Object[] arr = null;
ColumnsProcessor processor = null;
if (ref instanceof ProcessingRef) {
ProcessingRef pRef = (ProcessingRef) ref;
processor = pRef.processor();
ref = pRef.ref();
}
if (ref == TotalCountRef.INSTANCE) {
arr = new Object[] { processIfNeeded(processor, Long.valueOf(response.getHits().getTotalHits())) };
columns.add(arr);
}
else if (ref instanceof AggRef) {
// workaround for elastic/elasticsearch/issues/23056
String path = ((AggRef) ref).path();
boolean formattedKey = path.endsWith(Agg.PATH_BUCKET_VALUE_FORMATTED);
if (formattedKey) {
path = path.substring(0, path.length() - Agg.PATH_BUCKET_VALUE_FORMATTED.length());
}
Object value = getAggProperty(aggs, path);
if (formattedKey) {
List<? extends Bucket> buckets = ((MultiBucketsAggregation) value).getBuckets();
arr = new Object[buckets.size()];
for (int i = 0; i < buckets.size(); i++) {
arr[i] = buckets.get(i).getKeyAsString();
}
}
else {
arr = value instanceof Object[] ? (Object[]) value : new Object[] { value };
}
// process if needed
for (int i = 0; i < arr.length; i++) {
arr[i] = processIfNeeded(processor, arr[i]);
}
columns.add(arr);
}
// aggs without any grouping
else {
throw new SqlIllegalArgumentException("Unexpected non-agg/grouped column specified; %s", ref.getClass());
}
if (ref.depth() > maxDepth) {
maxDepth = ref.depth();
}
}
clearScroll(response.getScrollId());
return new AggsRowSetCursor(schema, columns, maxDepth, query.limit());
}
private static Object getAggProperty(Aggregations aggs, String path) {
List<String> list = AggregationPath.parse(path).getPathElementsAsStringList();
String aggName = list.get(0);
InternalAggregation agg = aggs.get(aggName);
if (agg == null) {
throw new ExecutionException("Cannot find an aggregation named %s", aggName);
}
return agg.getProperty(list.subList(1, list.size()));
}
private Object processIfNeeded(ColumnsProcessor processor, Object value) {
return processor != null ? processor.apply(value) : value;
}
AggsScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, QueryContainer query) {
super(listener, client, keepAlive, schema);
this.query = query;
}
/**
* Initial scroll used for parsing search hits (handles possible aggs).
*/
static class HandshakeScrollActionListener extends SearchHitsActionListener {
private final QueryContainer query;
@Override
protected RowSetCursor handleResponse(SearchResponse response) {
Aggregations aggs = response.getAggregations();
HandshakeScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, QueryContainer query) {
super(listener, client, keepAlive, schema, query.limit(), 0);
this.query = query;
}
List<Object[]> columns = new ArrayList<>();
@Override
public void onResponse(SearchResponse response) {
super.onResponse(response);
}
// this method assumes the nested aggregation are all part of the same tree (the SQL group-by)
int maxDepth = -1;
@Override
protected List<HitExtractor> getExtractors() {
// create response extractors for the first time
List<Reference> refs = query.refs();
for (Reference ref : query.refs()) {
Object[] arr = null;
List<HitExtractor> exts = new ArrayList<>(refs.size());
for (Reference ref : refs) {
exts.add(createExtractor(ref));
}
return exts;
}
private HitExtractor createExtractor(Reference ref) {
if (ref instanceof SearchHitFieldRef) {
SearchHitFieldRef f = (SearchHitFieldRef) ref;
return f.useDocValue() ? new DocValueExtractor(f.name()) : new SourceExtractor(f.name());
}
if (ref instanceof NestedFieldRef) {
NestedFieldRef f = (NestedFieldRef) ref;
return new InnerHitExtractor(f.parent(), f.name(), f.useDocValue());
}
if (ref instanceof ScriptFieldRef) {
ScriptFieldRef f = (ScriptFieldRef) ref;
return new DocValueExtractor(f.name());
}
ColumnProcessor processor = null;
if (ref instanceof ProcessingRef) {
ProcessingRef pRef = (ProcessingRef) ref;
return new ProcessingHitExtractor(createExtractor(pRef.ref()), pRef.processor());
processor = pRef.processor();
ref = pRef.ref();
}
throw new SqlIllegalArgumentException("Unexpected ValueReference %s", ref.getClass());
}
}
/**
* Listener used for streaming the rest of the results after the handshake has been used.
*/
static class SessionScrollActionListener extends SearchHitsActionListener {
private List<HitExtractor> exts;
SessionScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, List<HitExtractor> ext, int limit, int docCount) {
super(listener, client, keepAlive, schema, limit, docCount);
this.exts = ext;
}
@Override
protected List<HitExtractor> getExtractors() {
return exts;
}
}
abstract static class SearchHitsActionListener extends ScrollerActionListener {
final int limit;
int docsRead;
SearchHitsActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, int limit,
int docsRead) {
super(listener, client, keepAlive, schema);
this.limit = limit;
this.docsRead = docsRead;
}
protected RowSetCursor handleResponse(SearchResponse response) {
SearchHit[] hits = response.getHits().getHits();
List<HitExtractor> exts = getExtractors();
// there are some results
if (hits.length > 0) {
String scrollId = response.getScrollId();
Consumer<ActionListener<RowSetCursor>> next = null;
docsRead += hits.length;
// if there's an id, try to setup next scroll
if (scrollId != null) {
// is all the content already retrieved?
if (Boolean.TRUE.equals(response.isTerminatedEarly()) || response.getHits().getTotalHits() == hits.length
// or maybe the limit has been reached
|| docsRead >= limit) {
// if so, clear the scroll
clearScroll(scrollId);
// and remove it to indicate no more data is expected
scrollId = null;
} else {
next = l -> Scroller.from(l, this, response.getScrollId(), exts);
if (ref == TotalCountRef.INSTANCE) {
arr = new Object[] { processIfNeeded(processor, Long.valueOf(response.getHits().getTotalHits())) };
columns.add(arr);
}
else if (ref instanceof AggRef) {
// workaround for elastic/elasticsearch/issues/23056
String path = ((AggRef) ref).path();
boolean formattedKey = AggPath.isBucketValueFormatted(path);
if (formattedKey) {
path = AggPath.bucketValueWithoutFormat(path);
}
Object value = getAggProperty(aggs, path);
// // FIXME: this can be tabular in nature
// if (ref instanceof MappedAggRef) {
// Map<String, Object> map = (Map<String, Object>) value;
// Object extractedValue = map.get(((MappedAggRef) ref).fieldName());
// }
if (formattedKey) {
List<? extends Bucket> buckets = ((MultiBucketsAggregation) value).getBuckets();
arr = new Object[buckets.size()];
for (int i = 0; i < buckets.size(); i++) {
arr[i] = buckets.get(i).getKeyAsString();
}
}
int limitHits = limit > 0 && docsRead >= limit ? limit : -1;
return new SearchHitRowSetCursor(schema, exts, hits, limitHits, scrollId, next);
}
// no hits
else {
clearScroll(response.getScrollId());
// typically means last page but might be an aggs only query
return needsHit(exts) ? Rows.empty(schema) : new SearchHitRowSetCursor(schema, exts);
}
}
private static boolean needsHit(List<HitExtractor> exts) {
for (HitExtractor ext : exts) {
if (ext instanceof DocValueExtractor || ext instanceof ProcessingHitExtractor) {
return true;
else {
arr = value instanceof Object[] ? (Object[]) value : new Object[] { value };
}
// process if needed
for (int i = 0; i < arr.length; i++) {
arr[i] = processIfNeeded(processor, arr[i]);
}
columns.add(arr);
}
// aggs without any grouping
else {
throw new SqlIllegalArgumentException("Unexpected non-agg/grouped column specified; %s", ref.getClass());
}
if (ref.depth() > maxDepth) {
maxDepth = ref.depth();
}
return false;
}
protected abstract List<HitExtractor> getExtractors();
clearScroll(response.getScrollId());
return new AggsRowSetCursor(schema, columns, maxDepth, query.limit());
}
abstract static class ScrollerActionListener implements ActionListener<SearchResponse> {
final ActionListener<RowSetCursor> listener;
final Client client;
final TimeValue keepAlive;
final Schema schema;
ScrollerActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema) {
this.listener = listener;
this.client = client;
this.keepAlive = keepAlive;
this.schema = schema;
private static Object getAggProperty(Aggregations aggs, String path) {
List<String> list = AggregationPath.parse(path).getPathElementsAsStringList();
String aggName = list.get(0);
InternalAggregation agg = aggs.get(aggName);
if (agg == null) {
throw new ExecutionException("Cannot find an aggregation named %s", aggName);
}
return agg.getProperty(list.subList(1, list.size()));
}
// TODO: need to handle rejections plus check failures (shard size, etc...)
@Override
public void onResponse(final SearchResponse response) {
try {
ShardSearchFailure[] failure = response.getShardFailures();
if (!ObjectUtils.isEmpty(failure)) {
onFailure(new ExecutionException(failure[0].reason(), failure[0].getCause()));
}
listener.onResponse(handleResponse(response));
} catch (Exception ex) {
onFailure(ex);
}
}
protected abstract RowSetCursor handleResponse(SearchResponse response);
protected final void clearScroll(String scrollId) {
if (scrollId != null) {
// fire and forget
client.prepareClearScroll().addScrollId(scrollId).execute();
}
}
@Override
public final void onFailure(Exception ex) {
listener.onFailure(ex);
}
private Object processIfNeeded(ColumnProcessor processor, Object value) {
return processor != null ? processor.apply(value) : value;
}
}
// initial scroll used for parsing search hits (handles possible aggs)
class HandshakeScrollActionListener extends SearchHitsActionListener {
private final QueryContainer query;
HandshakeScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, QueryContainer query) {
super(listener, client, keepAlive, schema, query.limit(), 0);
this.query = query;
}
@Override
public void onResponse(SearchResponse response) {
super.onResponse(response);
}
@Override
protected List<HitExtractor> getExtractors() {
// create response extractors for the first time
List<Reference> refs = query.refs();
List<HitExtractor> exts = new ArrayList<>(refs.size());
for (Reference ref : refs) {
exts.add(createExtractor(ref));
}
return exts;
}
private HitExtractor createExtractor(Reference ref) {
if (ref instanceof SearchHitFieldRef) {
SearchHitFieldRef f = (SearchHitFieldRef) ref;
return f.useDocValue() ? new DocValueExtractor(f.name()) : new SourceExtractor(f.name());
}
if (ref instanceof NestedFieldRef) {
NestedFieldRef f = (NestedFieldRef) ref;
return new InnerHitExtractor(f.parent(), f.name(), f.useDocValue());
}
if (ref instanceof ScriptFieldRef) {
ScriptFieldRef f = (ScriptFieldRef) ref;
return new DocValueExtractor(f.name());
}
if (ref instanceof ProcessingRef) {
ProcessingRef pRef = (ProcessingRef) ref;
return new ProcessingHitExtractor(createExtractor(pRef.ref()), pRef.processor());
}
throw new SqlIllegalArgumentException("Unexpected ValueReference %s", ref.getClass());
}
}
// listener used for streaming the rest of the results after the handshake has been used
class SessionScrollActionListener extends SearchHitsActionListener {
private List<HitExtractor> exts;
SessionScrollActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, List<HitExtractor> ext, int limit, int docCount) {
super(listener, client, keepAlive, schema, limit, docCount);
this.exts = ext;
}
@Override
protected List<HitExtractor> getExtractors() {
return exts;
}
}
abstract class SearchHitsActionListener extends ScrollerActionListener {
final int limit;
int docsRead;
SearchHitsActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema, int limit, int docsRead) {
super(listener, client, keepAlive, schema);
this.limit = limit;
this.docsRead = docsRead;
}
protected RowSetCursor handleResponse(SearchResponse response) {
SearchHit[] hits = response.getHits().getHits();
List<HitExtractor> exts = getExtractors();
// there are some results
if (hits.length > 0) {
String scrollId = response.getScrollId();
Consumer<ActionListener<RowSetCursor>> next = null;
docsRead += hits.length;
// if there's an id, try to setup next scroll
if (scrollId != null) {
// is all the content already retrieved?
if (Boolean.TRUE.equals(response.isTerminatedEarly()) || response.getHits().getTotalHits() == hits.length
// or maybe the limit has been reached
|| docsRead >= limit) {
// if so, clear the scroll
clearScroll(scrollId);
// and remove it to indicate no more data is expected
scrollId = null;
}
else {
next = l -> Scroller.from(l, this, response.getScrollId(), exts);
}
}
int limitHits = limit > 0 && docsRead >= limit ? limit : -1;
return new SearchHitRowSetCursor(schema, exts, hits, limitHits, scrollId, next);
}
// no hits
else {
clearScroll(response.getScrollId());
// typically means last page but might be an aggs only query
return needsHit(exts) ? Rows.empty(schema) : new SearchHitRowSetCursor(schema, exts);
}
}
private static boolean needsHit(List<HitExtractor> exts) {
for (HitExtractor ext : exts) {
if (ext instanceof DocValueExtractor || ext instanceof ProcessingHitExtractor) {
return true;
}
}
return false;
}
protected abstract List<HitExtractor> getExtractors();
}
abstract class ScrollerActionListener implements ActionListener<SearchResponse> {
final ActionListener<RowSetCursor> listener;
final Client client;
final TimeValue keepAlive;
final Schema schema;
ScrollerActionListener(ActionListener<RowSetCursor> listener, Client client, TimeValue keepAlive, Schema schema) {
this.listener = listener;
this.client = client;
this.keepAlive = keepAlive;
this.schema = schema;
}
// TODO: need to handle rejections plus check failures (shard size, etc...)
@Override
public void onResponse(final SearchResponse response) {
try {
ShardSearchFailure[] failure = response.getShardFailures();
if (!ObjectUtils.isEmpty(failure)) {
onFailure(new ExecutionException(failure[0].reason(), failure[0].getCause()));
}
listener.onResponse(handleResponse(response));
} catch (Exception ex) {
onFailure(ex);
}
}
protected abstract RowSetCursor handleResponse(SearchResponse response);
protected final void clearScroll(String scrollId) {
if (scrollId != null) {
// fire and forget
client.prepareClearScroll().addScrollId(scrollId).execute();
}
}
@Override
public final void onFailure(Exception ex) {
listener.onFailure(ex);
}
}

View File

@ -51,6 +51,7 @@ public class SearchHitRowSetCursor extends AbstractRowSetCursor {
InnerHitExtractor ie = getInnerHitExtractor(ex);
if (ie != null) {
innerH = ie.parent();
innerHits.add(innerH);
}
}

View File

@ -14,7 +14,7 @@ import org.elasticsearch.xpack.sql.type.DataType;
abstract class UnresolvedNamedExpression extends NamedExpression implements Unresolvable {
UnresolvedNamedExpression(Location location, List<Expression> children) {
public UnresolvedNamedExpression(Location location, List<Expression> children) {
super(location, "<unresolved>", children, ExpressionIdGenerator.EMPTY);
}

View File

@ -71,7 +71,7 @@ abstract class AbstractFunctionRegistry implements FunctionRegistry {
@Override
public Collection<FunctionDefinition> listFunctions(String pattern) {
Pattern p = Strings.hasText(pattern) ? Pattern.compile(normalize(pattern)) : null;
Pattern p = Strings.hasText(pattern) ? StringUtils.likeRegex(normalize(pattern)) : null;
return defs.entrySet().stream()
.filter(e -> p == null || p.matcher(e.getKey()).matches())
.map(e -> new FunctionDefinition(e.getKey(), emptyList(), e.getValue().clazz()))

View File

@ -6,11 +6,18 @@
package org.elasticsearch.xpack.sql.expression.function;
import org.elasticsearch.xpack.sql.SqlException;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Avg;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Mean;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness;
import org.elasticsearch.xpack.sql.expression.function.aggregate.StddevPop;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Sum;
import org.elasticsearch.xpack.sql.expression.function.aggregate.SumOfSquares;
import org.elasticsearch.xpack.sql.expression.function.aggregate.VarPop;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfMonth;
import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DayOfWeek;
@ -72,13 +79,21 @@ public class DefaultFunctionRegistry extends AbstractFunctionRegistry {
return ALIASES;
}
private static Collection<Class<? extends Function>> agg() {
private static Collection<Class<? extends AggregateFunction>> agg() {
return Arrays.asList(
Avg.class,
Count.class,
Max.class,
Min.class,
Sum.class
Sum.class,
// statistics
Mean.class,
StddevPop.class,
VarPop.class,
SumOfSquares.class,
Skewness.class,
Kurtosis.class
// TODO: add multi arg functions like Covariance, Correlate, Percentiles and percentiles rank
);
}
@ -128,6 +143,7 @@ public class DefaultFunctionRegistry extends AbstractFunctionRegistry {
);
}
@SuppressWarnings("unchecked")
private static Collection<Class<? extends ScalarFunction>> functions(Class<? extends ScalarFunction> type) {
String path = type.getPackage().getName().replace('.', '/');

View File

@ -67,4 +67,8 @@ public abstract class Function extends NamedExpression {
}
return sj.toString();
}
public boolean functionEquals(Function f) {
return f != null && getClass() == f.getClass() && arguments().equals(f.arguments());
}
}

View File

@ -16,7 +16,7 @@ public enum FunctionType {
private final Class<? extends Function> baseClass;
FunctionType(Class<? extends Function> base) {
private FunctionType(Class<? extends Function> base) {
this.baseClass = base;
}

View File

@ -12,7 +12,7 @@ import org.elasticsearch.xpack.sql.expression.Alias;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import static java.util.Collections.emptyList;
@ -80,13 +80,13 @@ public abstract class Functions {
return exps;
}
public static ColumnsProcessor chainProcessors(List<Expression> unwrappedScalar) {
ColumnsProcessor proc = null;
public static ColumnProcessor chainProcessors(List<Expression> unwrappedScalar) {
ColumnProcessor proc = null;
for (Expression e : unwrappedScalar) {
if (e instanceof ScalarFunction) {
ScalarFunction sf = (ScalarFunction) e;
// A(B(C)) is applied backwards first C then B then A, the last function first
proc = proc != null ? sf.asProcessor().andThen(proc) : sf.asProcessor();
proc = sf.asProcessor().andThen(proc);
}
else {
return proc;

View File

@ -24,9 +24,13 @@ public abstract class AggregateFunction extends Function {
return argument;
}
public String functionId() {
return id().toString();
}
@Override
public AggregateFunctionAttribute toAttribute() {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
return new AggregateFunctionAttribute(location(), name(), dataType(), id(), id().toString(), null);
return new AggregateFunctionAttribute(location(), name(), dataType(), id(), functionId(), null);
}
}

View File

@ -49,6 +49,10 @@ public class AggregateFunctionAttribute extends TypedAttribute {
return new AggregateFunctionAttribute(location, name, dataType, qualifier, nullable, id, synthetic, functionId, propertyPath);
}
public AggregateFunctionAttribute withFunctionId(String functionId, String propertyPath) {
return new AggregateFunctionAttribute(location(), name(), dataType(), qualifier(), nullable(), id(), synthetic(), functionId, propertyPath);
}
@Override
public boolean equals(Object obj) {
if (super.equals(obj)) {

View File

@ -7,17 +7,15 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
public class Avg extends NumericAggregateFunction {
public class Avg extends NumericAggregate implements EnclosedAgg {
public Avg(Location location, Expression argument) {
super(location, argument);
}
@Override
public DataType dataType() {
return DataTypes.DOUBLE;
public String innerName() {
return "avg";
}
}

View File

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
// marker type for compound aggregates, that is aggregate that provide multiple values (like Stats or Matrix)
// and thus cannot be used directly in SQL and are mainly for internal use
public abstract class CompoundAggregate extends NumericAggregate {
public CompoundAggregate(Location location, Expression argument) {
super(location, argument);
}
}

View File

@ -29,13 +29,19 @@ public class Count extends AggregateFunction {
return DataTypes.LONG;
}
@Override
public AggregateFunctionAttribute toAttribute() {
public String functionId() {
String functionId = id().toString();
// if count works against a given expression, use its id (to identify the group)
if (argument() instanceof NamedExpression) {
functionId = ((NamedExpression) argument()).id().toString();
}
return new AggregateFunctionAttribute(location(), name(), dataType(), id(), functionId, "_count");
return functionId;
}
@Override
public AggregateFunctionAttribute toAttribute() {
return new AggregateFunctionAttribute(location(), name(), dataType(), id(), functionId(), "_count");
}
}

View File

@ -0,0 +1,11 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
public interface EnclosedAgg {
String innerName();
}

View File

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class ExtendedStats extends CompoundAggregate {
public ExtendedStats(Location location, Expression argument) {
super(location, argument);
}
}

View File

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
public interface ExtendedStatsEnclosed extends StatsEnclosed, EnclosedAgg {
}

View File

@ -0,0 +1,78 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.querydsl.agg.AggPath;
import org.elasticsearch.xpack.sql.type.DataType;
public class InnerAggregate extends AggregateFunction {
private final AggregateFunction inner;
private final CompoundAggregate outer;
private final String innerId;
// used when the result needs to be extracted from a map (like in MatrixAggs)
private final Expression innerKey;
public InnerAggregate(AggregateFunction inner, CompoundAggregate outer) {
this(inner, outer, null);
}
public InnerAggregate(AggregateFunction inner, CompoundAggregate outer, Expression innerKey) {
super(inner.location(), outer.argument());
this.inner = inner;
this.outer = outer;
this.innerId = ((EnclosedAgg) inner).innerName();
this.innerKey = innerKey;
}
public AggregateFunction inner() {
return inner;
}
public CompoundAggregate outer() {
return outer;
}
public String innerId() {
return innerId;
}
public Expression innerKey() {
return innerKey;
}
@Override
public DataType dataType() {
return inner.dataType();
}
@Override
public String functionId() {
return outer.id().toString();
}
@Override
public AggregateFunctionAttribute toAttribute() {
// this is highly correlated with QueryFolder$FoldAggregate#addFunction (regarding the function name within the querydsl)
return new AggregateFunctionAttribute(location(), name(), dataType(), outer.id(), functionId(), AggPath.metricValue(functionId(), innerId));
}
@Override
public boolean functionEquals(Function f) {
if (super.equals(f)) {
InnerAggregate other = (InnerAggregate) f;
return inner.equals(other.inner) && outer.equals(other.outer);
}
return false;
}
@Override
public String name() {
return "(" + inner.functionName() + "#" + inner.id() + "/" + outer.toString() + ")";
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class Kurtosis extends NumericAggregate implements MatrixStatsEnclosed {
public Kurtosis(Location location, Expression argument) {
super(location, argument);
}
@Override
public String innerName() {
return "kurtosis";
}
}

View File

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class MatrixStats extends CompoundAggregate {
public MatrixStats(Location location, Expression argument) {
super(location, argument);
}
}

View File

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
public interface MatrixStatsEnclosed extends EnclosedAgg {
}

View File

@ -7,10 +7,21 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
public class Max extends NumericAggregateFunction {
public class Max extends NumericAggregate implements EnclosedAgg {
public Max(Location location, Expression argument) {
super(location, argument);
}
@Override
public DataType dataType() {
return argument().dataType();
}
@Override
public String innerName() {
return "max";
}
}

View File

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
public class Mean extends NumericAggregate implements MatrixStatsEnclosed {
public Mean(Location location, Expression argument) {
super(location, argument);
}
@Override
public DataType dataType() {
return DataTypes.DOUBLE;
}
@Override
public String innerName() {
return "means";
}
}

View File

@ -7,10 +7,21 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
public class Min extends NumericAggregateFunction {
public class Min extends NumericAggregate implements EnclosedAgg {
public Min(Location location, Expression argument) {
super(location, argument);
}
@Override
public DataType dataType() {
return argument().dataType();
}
@Override
public String innerName() {
return "min";
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;
abstract class NumericAggregate extends AggregateFunction {
NumericAggregate(Location location, Expression argument) {
super(location, argument);
}
@Override
protected TypeResolution resolveType() {
return argument().dataType().isNumeric() ?
TypeResolution.TYPE_RESOLVED :
new TypeResolution("Function '%s' cannot be applied on a non-numeric expression ('%s' of type '%s')", functionName(), Expressions.name(argument()), argument().dataType().esName());
}
@Override
public DataType dataType() {
return DataTypes.DOUBLE;
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class Skewness extends NumericAggregate implements MatrixStatsEnclosed {
public Skewness(Location location, Expression argument) {
super(location, argument);
}
@Override
public String innerName() {
return "skewness";
}
}

View File

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class Stats extends CompoundAggregate {
public Stats(Location location, Expression argument) {
super(location, argument);
}
public static boolean isTypeCompatible(Expression e) {
return e instanceof Min || e instanceof Max || e instanceof Avg || e instanceof Sum;
}
}

View File

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
public interface StatsEnclosed {
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class StddevPop extends NumericAggregate implements ExtendedStatsEnclosed {
public StddevPop(Location location, Expression argument) {
super(location, argument);
}
@Override
public String innerName() {
return "std_deviation";
}
}

View File

@ -7,10 +7,21 @@ package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
public class Sum extends NumericAggregateFunction {
public class Sum extends NumericAggregate implements EnclosedAgg {
public Sum(Location location, Expression argument) {
super(location, argument);
}
@Override
public DataType dataType() {
return argument().dataType();
}
@Override
public String innerName() {
return "sum";
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class SumOfSquares extends NumericAggregate implements ExtendedStatsEnclosed {
public SumOfSquares(Location location, Expression argument) {
super(location, argument);
}
@Override
public String innerName() {
return "sum_of_squares";
}
}

View File

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.expression.function.aggregate;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.tree.Location;
public class VarPop extends NumericAggregate implements ExtendedStatsEnclosed {
public VarPop(Location location, Expression argument) {
super(location, argument);
}
@Override
public String innerName() {
return "variance";
}
}

View File

@ -77,7 +77,7 @@ public class Cast extends ScalarFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return c -> DataTypeConvertion.convert(c, from(), to());
}

View File

@ -5,15 +5,12 @@
*/
package org.elasticsearch.xpack.sql.expression.function.scalar;
import java.util.Objects;
@FunctionalInterface
public interface ColumnsProcessor {
public interface ColumnProcessor {
Object apply(Object t);
Object apply(Object r);
default ColumnsProcessor andThen(ColumnsProcessor after) {
Objects.requireNonNull(after);
return t -> after.apply(apply(t));
default ColumnProcessor andThen(ColumnProcessor after) {
return after != null ? r -> after.apply(apply(r)) : this;
}
}

View File

@ -81,7 +81,7 @@ public abstract class ScalarFunction extends Function {
protected abstract String chainScalarTemplate(String template);
public abstract ColumnsProcessor asProcessor();
public abstract ColumnProcessor asProcessor();
// used if the function is monotonic and thus does not have to be computed for ordering purposes
public Expression orderBy() {

View File

@ -9,7 +9,7 @@ import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.tree.Location;
@ -64,7 +64,7 @@ public abstract class DateTimeFunction extends ScalarFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return l -> {
ReadableDateTime dt = null;
// most dates are returned as long

View File

@ -6,7 +6,7 @@
package org.elasticsearch.xpack.sql.expression.function.scalar.math;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
@ -17,7 +17,7 @@ public class Abs extends MathFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return l -> {
if (l instanceof Float) {
return Math.abs(((Float) l).floatValue());

View File

@ -7,7 +7,7 @@ package org.elasticsearch.xpack.sql.expression.function.scalar.math;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.util.StringUtils;
@ -28,7 +28,7 @@ public class E extends MathFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return l -> Math.E;
}

View File

@ -10,7 +10,7 @@ import java.util.Locale;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate;
@ -79,7 +79,7 @@ public abstract class MathFunction extends ScalarFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return l -> {
double d = ((Number) l).doubleValue();
return math(d);

View File

@ -7,7 +7,7 @@ package org.elasticsearch.xpack.sql.expression.function.scalar.math;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.util.StringUtils;
@ -28,7 +28,7 @@ public class Pi extends MathFunction {
}
@Override
public ColumnsProcessor asProcessor() {
public ColumnProcessor asProcessor() {
return l -> Math.PI;
}

View File

@ -16,7 +16,7 @@ import org.elasticsearch.xpack.sql.type.DataTypes;
public class FullTextPredicate extends Expression {
public enum Operator {
public static enum Operator {
AND,
OR;

View File

@ -20,7 +20,14 @@ import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.NestedFieldAttribute;
import org.elasticsearch.xpack.sql.expression.Order;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStats;
import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Stats;
import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
import org.elasticsearch.xpack.sql.expression.predicate.And;
import org.elasticsearch.xpack.sql.expression.predicate.BinaryComparison;
@ -49,9 +56,11 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.sql.expression.Literal.FALSE;
@ -88,7 +97,11 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Batch aggregate = new Batch("Aggregation",
new PruneDuplicatesInGroupBy(),
new ReplaceDuplicateAggsWithReferences()
new ReplaceDuplicateAggsWithReferences(),
new CombineAggsToMatrixStats(),
new CombineAggsToExtendedStats(),
new CombineAggsToStats(),
new PromoteStatsToExtendedStats()
);
Batch cleanup = new Batch("Operator Optimization",
@ -112,6 +125,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
return Arrays.asList(resolution, aggregate, cleanup, label);
}
static class PruneSubqueryAliases extends OptimizerRule<SubQueryAlias> {
PruneSubqueryAliases() {
@ -136,13 +150,11 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Project p = (Project) plan;
return new Project(p.location(), p.child(), cleanExpressions(p.projections()));
}
if (plan instanceof Aggregate) {
Aggregate a = (Aggregate) plan;
// clean group expressions
List<Expression> cleanedGroups = a.groupings().stream()
.map(this::trimAliases)
.collect(toList());
List<Expression> cleanedGroups = a.groupings().stream().map(this::trimAliases).collect(toList());
return new Aggregate(a.location(), a.child(), cleanedGroups, cleanExpressions(a.aggregates()));
}
@ -155,9 +167,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
private List<NamedExpression> cleanExpressions(List<? extends NamedExpression> args) {
return args.stream()
.map(this::trimNonTopLevelAliases)
.map(NamedExpression.class::cast)
return args.stream().map(this::trimNonTopLevelAliases).map(NamedExpression.class::cast)
.collect(toList());
}
@ -211,7 +221,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
reverse.putIfAbsent(ne, ne.canonical());
}
}
if (unique.size() != aggs.size()) {
List<NamedExpression> newAggs = new ArrayList<>(aggs.size());
for (NamedExpression ne : aggs) {
@ -224,6 +234,206 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
}
static class CombineAggsToMatrixStats extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
Map<Expression, MatrixStats> seen = new LinkedHashMap<>();
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
p = p.transformExpressionsUp(e -> rule(e, seen, promotedFunctionIds));
return p.transformExpressionsDown(e -> CombineAggsToStats.updateFunctionAttrs(e, promotedFunctionIds));
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
protected Expression rule(Expression e, Map<Expression, MatrixStats> seen, Map<String, AggregateFunctionAttribute> promotedIds) {
if (e instanceof MatrixStatsEnclosed) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
MatrixStats matrixStats = seen.get(argument);
if (matrixStats == null) {
matrixStats = new MatrixStats(f.location(), argument);
seen.put(argument, matrixStats);
}
InnerAggregate ia = new InnerAggregate(f, matrixStats, f.argument());
promotedIds.putIfAbsent(f.functionId(), ia.toAttribute());
return ia;
}
return e;
}
}
static class CombineAggsToExtendedStats extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
Map<Expression, ExtendedStats> seen = new LinkedHashMap<>();
p = p.transformExpressionsUp(e -> rule(e, seen, promotedFunctionIds));
// update old agg attributes
return p.transformExpressionsDown(e -> CombineAggsToStats.updateFunctionAttrs(e, promotedFunctionIds));
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
protected Expression rule(Expression e, Map<Expression, ExtendedStats> seen, Map<String, AggregateFunctionAttribute> promotedIds) {
if (e instanceof ExtendedStatsEnclosed) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
ExtendedStats extendedStats = seen.get(argument);
if (extendedStats == null) {
extendedStats = new ExtendedStats(f.location(), argument);
seen.put(argument, extendedStats);
}
InnerAggregate ia = new InnerAggregate(f, extendedStats);
promotedIds.putIfAbsent(f.functionId(), ia.toAttribute());
return ia;
}
return e;
}
}
static class CombineAggsToStats extends Rule<LogicalPlan, LogicalPlan> {
private static class Counter {
final Stats stats;
int count = 1;
final Set<Class<? extends AggregateFunction>> functionTypes = new LinkedHashSet<>();
Counter(Stats stats) {
this.stats = stats;
}
}
@Override
public LogicalPlan apply(LogicalPlan p) {
Map<Expression, Counter> potentialPromotions = new LinkedHashMap<>();
// old functionId to new function attribute
Map<String, AggregateFunctionAttribute> promotedFunctionIds = new LinkedHashMap<>();
p.forEachExpressionsUp(e -> count(e, potentialPromotions));
// promote aggs to InnerAggs
p = p.transformExpressionsUp(e -> promote(e, potentialPromotions, promotedFunctionIds));
// update old agg attributes (TODO: this might be applied while updating the InnerAggs since the promotion happens bottom-up (and thus any attributes should be only in higher nodes)
return p.transformExpressionsDown(e -> updateFunctionAttrs(e, promotedFunctionIds));
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
private Expression count(Expression e, Map<Expression, Counter> seen) {
if (Stats.isTypeCompatible(e)) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Counter counter = seen.get(argument);
if (counter == null) {
counter = new Counter(new Stats(f.location(), argument));
counter.functionTypes.add(f.getClass());
seen.put(argument, counter);
}
else {
if (counter.functionTypes.add(f.getClass())) {
counter.count++;
}
}
}
return e;
}
private Expression promote(Expression e, Map<Expression, Counter> seen, Map<String, AggregateFunctionAttribute> attrs) {
if (Stats.isTypeCompatible(e)) {
AggregateFunction f = (AggregateFunction) e;
Expression argument = f.argument();
Counter counter = seen.get(argument);
// if the stat has at least two different functions for it, promote it as stat
if (counter != null && counter.count > 1) {
InnerAggregate innerAgg = new InnerAggregate(f, counter.stats);
attrs.putIfAbsent(f.functionId(), innerAgg.toAttribute());
return innerAgg;
}
}
return e;
}
static Expression updateFunctionAttrs(Expression e, Map<String, AggregateFunctionAttribute> promotedIds) {
if (e instanceof AggregateFunctionAttribute) {
AggregateFunctionAttribute ae = (AggregateFunctionAttribute) e;
AggregateFunctionAttribute promoted = promotedIds.get(ae.functionId());
if (promoted != null) {
return ae.withFunctionId(promoted.functionId(), promoted.propertyPath());
}
}
return e;
}
}
static class PromoteStatsToExtendedStats extends Rule<LogicalPlan, LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
Map<Expression, ExtendedStats> seen = new LinkedHashMap<>();
// count the extended stats
p.forEachExpressionsUp(e -> count(e, seen));
// then if there's a match, replace the stat inside the InnerAgg
return p.transformExpressionsUp(e -> promote(e, seen));
}
@Override
protected LogicalPlan rule(LogicalPlan e) {
return e;
}
private void count(Expression e, Map<Expression, ExtendedStats> seen) {
if (e instanceof InnerAggregate) {
InnerAggregate ia = (InnerAggregate) e;
if (ia.outer() instanceof ExtendedStats) {
ExtendedStats extStats = (ExtendedStats) ia.outer();
seen.putIfAbsent(extStats.argument(), extStats);
}
}
}
protected Expression promote(Expression e, Map<Expression, ExtendedStats> seen) {
if (e instanceof InnerAggregate) {
InnerAggregate ia = (InnerAggregate) e;
if (ia.outer() instanceof Stats) {
Stats stats = (Stats) ia.outer();
ExtendedStats ext = seen.get(stats.argument());
if (ext != null && stats.argument().equals(ext.argument())) {
return new InnerAggregate(ia.inner(), ext);
}
}
}
return e;
}
}
static class PruneFilters extends OptimizerRule<Filter> {
@Override
@ -291,14 +501,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// count the direct parents
Map<String, Order> nestedOrders = new LinkedHashMap<>();
for (Order order : ob.order()) {
Attribute attr = ((NamedExpression) order.child()).toAttribute();
if (attr instanceof NestedFieldAttribute) {
nestedOrders.put(((NestedFieldAttribute) attr).parentPath(), order);
}
}
// no nested fields in sort
if (nestedOrders.isEmpty()) {
return project;
@ -346,7 +556,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
}
return project;
}
}
}
static class PruneOrderBy extends OptimizerRule<OrderBy> {
@ -356,10 +566,8 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
List<Order> order = ob.order();
// remove constants
List<Order> nonConstant = order.stream()
.filter(o -> !o.child().foldable())
.collect(toList());
List<Order> nonConstant = order.stream().filter(o -> !o.child().foldable()).collect(toList());
if (nonConstant.isEmpty()) {
return ob.child();
}
@ -371,14 +579,12 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (a.groupings().isEmpty()) {
AttributeSet aggsAttr = new AttributeSet(Expressions.asAttributes(a.aggregates()));
List<Order> nonAgg = nonConstant.stream()
.filter(o -> {
if (o.child() instanceof NamedExpression) {
return !aggsAttr.contains(((NamedExpression) o.child()).toAttribute());
}
return true;
})
.collect(toList());
List<Order> nonAgg = nonConstant.stream().filter(o -> {
if (o.child() instanceof NamedExpression) {
return !aggsAttr.contains(((NamedExpression) o.child()).toAttribute());
}
return true;
}).collect(toList());
return nonAgg.isEmpty() ? ob.child() : new OrderBy(ob.location(), ob.child(), nonAgg);
}
@ -409,7 +615,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override
protected LogicalPlan rule(LogicalPlan plan) {
final Map<Attribute, Attribute> replacedCast = new LinkedHashMap<>();
// first eliminate casts inside Aliases
LogicalPlan transformed = plan.transformExpressionsUp(e -> {
// cast wrapped in an alias
@ -445,13 +651,13 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
return e;
});
// replace attributes from previous removed Casts
if (!replacedCast.isEmpty()) {
return transformed.transformUp(p -> {
List<Attribute> newProjections = new ArrayList<>();
boolean changed = false;
for (NamedExpression ne : p.projections()) {
Attribute found = replacedCast.get(ne.toAttribute());
@ -463,9 +669,9 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
newProjections.add(ne.toAttribute());
}
}
return changed ? new Project(p.location(), p.child(), newProjections) : p;
}, Project.class);
}
return transformed;
@ -477,7 +683,6 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
@Override
public LogicalPlan apply(LogicalPlan p) {
List<Function> seen = new ArrayList<>();
return p.transformExpressionsUp(e -> rule(e, seen));
}
@ -491,19 +696,14 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
if (e instanceof Function) {
Function f = (Function) e;
for (Function seenFunction : seen) {
if (seenFunction != f && functionsEquals(f, seenFunction)) {
if (seenFunction != f && f.functionEquals(seenFunction)) {
return seenFunction;
}
}
seen.add(f);
}
return exp;
}
private boolean functionsEquals(Function f, Function seenFunction) {
return f.name().equals(seenFunction.name()) && f.arguments().equals(seenFunction.arguments());
}
}
static class SkipQueryOnLimitZero extends OptimizerRule<Limit> {
@ -548,13 +748,21 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
private Expression simplifyAndOr(BinaryExpression bc) {
Expression l = bc.left();
Expression r = bc.right();
if (bc instanceof And) {
if (TRUE.equals(l)) { return r; }
if (TRUE.equals(r)) { return l; }
if (FALSE.equals(l) || FALSE.equals(r)) { return FALSE; }
if (l.canonicalEquals(r)) { return l; }
if (bc instanceof And) {
if (TRUE.equals(l)) {
return r;
}
if (TRUE.equals(r)) {
return l;
}
if (FALSE.equals(l) || FALSE.equals(r)) {
return FALSE;
}
if (l.canonicalEquals(r)) {
return l;
}
//
// common factor extraction -> (a || b) && (a || c) => a && (b || c)
@ -577,14 +785,22 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
Expression combineRight = combineOr(rDiff);
return combineOr(CollectionUtils.combine(common, new And(combineLeft.location(), combineLeft, combineRight)));
}
if (bc instanceof Or) {
if (TRUE.equals(l) || TRUE.equals(r)) { return TRUE; }
if (TRUE.equals(l)) { return r; }
if (TRUE.equals(r)) { return l; }
if (l.canonicalEquals(r)) { return l; }
if (bc instanceof Or) {
if (TRUE.equals(l) || TRUE.equals(r)) {
return TRUE;
}
if (TRUE.equals(l)) {
return r;
}
if (TRUE.equals(r)) {
return l;
}
if (l.canonicalEquals(r)) {
return l;
}
//
// common factor extraction -> (a && b) || (a && c) => a || (b & c)
@ -615,8 +831,12 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
private Expression simplifyNot(Not n) {
Expression c = n.child();
if (TRUE.equals(c)) { return FALSE; }
if (FALSE.equals(c)) { return TRUE; }
if (TRUE.equals(c)) {
return FALSE;
}
if (FALSE.equals(c)) {
return TRUE;
}
if (c instanceof Negateable) {
return ((Negateable) c).negate();
@ -686,20 +906,24 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
// if the same operator is used
BinaryComparison lb = (BinaryComparison) l;
BinaryComparison rb = (BinaryComparison) r;
if (lb.left().equals(((BinaryComparison) r).left()) && lb.right() instanceof Literal && rb.right() instanceof Literal) {
// >/>= AND </<=
if ((l instanceof GreaterThan || l instanceof GreaterThanOrEqual) && (r instanceof LessThan || r instanceof LessThanOrEqual)) {
return new Range(and.location(), lb.left(), lb.right(), l instanceof GreaterThanOrEqual, rb.right(), r instanceof LessThanOrEqual);
if ((l instanceof GreaterThan || l instanceof GreaterThanOrEqual)
&& (r instanceof LessThan || r instanceof LessThanOrEqual)) {
return new Range(and.location(), lb.left(), lb.right(), l instanceof GreaterThanOrEqual, rb.right(),
r instanceof LessThanOrEqual);
}
// </<= AND >/>=
else if ((r instanceof GreaterThan || r instanceof GreaterThanOrEqual) && (l instanceof LessThan || l instanceof LessThanOrEqual)) {
return new Range(and.location(), rb.left(), rb.right(), r instanceof GreaterThanOrEqual, lb.right(), l instanceof LessThanOrEqual);
else if ((r instanceof GreaterThan || r instanceof GreaterThanOrEqual)
&& (l instanceof LessThan || l instanceof LessThanOrEqual)) {
return new Range(and.location(), rb.left(), rb.right(), r instanceof GreaterThanOrEqual, lb.right(),
l instanceof LessThanOrEqual);
}
}
}
return and;
}
}
@ -723,7 +947,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
}
abstract static class OptimizerRule<SubPlan extends LogicalPlan> extends Rule<SubPlan, LogicalPlan> {
static abstract class OptimizerRule<SubPlan extends LogicalPlan> extends Rule<SubPlan, LogicalPlan> {
private final boolean transformDown;
@ -745,7 +969,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
protected abstract LogicalPlan rule(SubPlan plan);
}
abstract static class OptimizerExpressionUpRule extends Rule<LogicalPlan, LogicalPlan> {
static abstract class OptimizerExpressionUpRule extends Rule<LogicalPlan, LogicalPlan> {
private final boolean transformDown;
@ -768,7 +992,7 @@ public class Optimizer extends RuleExecutor<LogicalPlan> {
protected LogicalPlan rule(LogicalPlan plan) {
return plan;
}
protected abstract Expression rule(Expression e);
}
}
}

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.sql.parser;
import java.util.Locale;
import org.elasticsearch.common.Booleans;
import org.elasticsearch.xpack.sql.parser.SqlBaseLexer;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.DebugContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ExplainContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.SessionResetContext;

View File

@ -97,21 +97,25 @@ public abstract class QueryPlan<PlanType extends QueryPlan<PlanType>> extends No
}
public void forEachExpressionsDown(Consumer<? super Expression> rule) {
forEachPropertiesDown(p -> doForEachExpression(p, rule), Object.class);
forEachPropertiesDown(e -> doForEachExpression(e, exp -> exp.forEachDown(rule)), Object.class);
}
public void forEachExpressionsUp(Consumer<? super Expression> rule) {
forEachPropertiesUp(p -> doForEachExpression(p, rule), Object.class);
forEachPropertiesUp(e -> doForEachExpression(e, exp -> exp.forEachUp(rule)), Object.class);
}
private void doForEachExpression(Object arg, Consumer<? super Expression> f) {
public void forEachExpressions(Consumer<? super Expression> rule) {
forEachPropertiesOnly(e -> doForEachExpression(e, rule::accept), Object.class);
}
private void doForEachExpression(Object arg, Consumer<? super Expression> traversal) {
if (arg instanceof Expression) {
f.accept((Expression) arg);
traversal.accept((Expression) arg);
}
else if (arg instanceof Collection) {
Collection<?> c = (Collection<?>) arg;
for (Object o : c) {
doForEachExpression(o, f);
doForEachExpression(o, traversal);
}
}
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataTypes;
import org.elasticsearch.xpack.sql.util.StringUtils;
import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
@ -52,7 +53,7 @@ public class SessionReset extends Command {
Settings defaults = session.defaults().cfg();
Builder builder = Settings.builder().put(s);
if (pattern != null) {
Pattern p = Pattern.compile(pattern);
Pattern p = StringUtils.likeRegex(pattern);
s.getAsMap().forEach((k, v) -> {
if (p.matcher(k).matches()) {
builder.put(k, defaults.get(k));

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.sql.session.Rows;
import org.elasticsearch.xpack.sql.session.SqlSession;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataTypes;
import org.elasticsearch.xpack.sql.util.StringUtils;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
@ -57,7 +58,7 @@ public class ShowSession extends Command {
}
else {
if (pattern != null) {
Pattern p = Pattern.compile(pattern);
Pattern p = StringUtils.likeRegex(pattern);
s = s.filter(k -> p.matcher(k).matches());
}

View File

@ -22,8 +22,10 @@ import org.elasticsearch.xpack.sql.expression.RootFieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.Function;
import org.elasticsearch.xpack.sql.expression.function.Functions;
import org.elasticsearch.xpack.sql.expression.function.aggregate.AggregateFunction;
import org.elasticsearch.xpack.sql.expression.function.aggregate.CompoundAggregate;
import org.elasticsearch.xpack.sql.expression.function.aggregate.Count;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.aggregate.InnerAggregate;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.plan.physical.AggregateExec;
@ -34,11 +36,13 @@ import org.elasticsearch.xpack.sql.plan.physical.OrderExec;
import org.elasticsearch.xpack.sql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.sql.plan.physical.ProjectExec;
import org.elasticsearch.xpack.sql.plan.physical.QuerylessExec;
import org.elasticsearch.xpack.sql.planner.QueryTranslator.GroupInfo;
import org.elasticsearch.xpack.sql.planner.QueryTranslator.GroupingContext;
import org.elasticsearch.xpack.sql.planner.QueryTranslator.QueryTranslation;
import org.elasticsearch.xpack.sql.querydsl.agg.AggFilter;
import org.elasticsearch.xpack.sql.querydsl.agg.AggPath;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.agg.GroupingAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg;
import org.elasticsearch.xpack.sql.querydsl.container.AttributeSort;
import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
import org.elasticsearch.xpack.sql.querydsl.container.ScriptSort;
@ -90,7 +94,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
QueryContainer queryC = exec.queryContainer();
Map<Attribute, Attribute> aliases = new LinkedHashMap<>(queryC.aliases());
Map<Attribute, ColumnsProcessor> processors = new LinkedHashMap<>(queryC.processors());
Map<Attribute, ColumnProcessor> processors = new LinkedHashMap<>(queryC.processors());
for (NamedExpression pj : project.projections()) {
if (pj instanceof Alias) {
@ -122,11 +126,11 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
return project;
}
private Attribute scalarToProcessor(ScalarFunction e, Map<Attribute, ColumnsProcessor> processors) {
private Attribute scalarToProcessor(ScalarFunction e, Map<Attribute, ColumnProcessor> processors) {
List<Expression> trail = Functions.unwrapScalarFunctionWithTail(e);
Expression tail = trail.get(trail.size() - 1);
ColumnsProcessor proc = Functions.chainProcessors(trail);
ColumnProcessor proc = Functions.chainProcessors(trail);
// in projection, scalar functions can only be applied to constants (in which case they are folded) or columns aka NamedExpressions
if (!(tail instanceof NamedExpression)) {
@ -222,16 +226,18 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// build the group aggregation
// and also collect info about it (since the group columns might be used inside the select)
GroupInfo groupInfo = QueryTranslator.groupBy(a.groupings());
GroupingContext groupingContext = QueryTranslator.groupBy(a.groupings());
// shortcut used in several places
Map<ExpressionId, GroupingAgg> aggMap = groupInfo != null ? groupInfo.aggMap : emptyMap();
Map<ExpressionId, GroupingAgg> groupMap = groupingContext != null ? groupingContext.groupMap : emptyMap();
QueryContainer queryC = exec.queryContainer();
if (groupInfo != null) {
queryC = queryC.addGroups(groupInfo.aggMap.values());
if (groupingContext != null) {
queryC = queryC.addGroups(groupingContext.groupMap.values());
}
Map<Attribute, Attribute> aliases = new LinkedHashMap<>();
// tracker for compound aggs seen in a group
Map<CompoundAggregate, String> compoundAggMap = new LinkedHashMap<>();
// followed by actual aggregates
for (NamedExpression ne : a.aggregates()) {
@ -257,14 +263,14 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
//
List<Expression> wrappingFunctions = Functions.unwrapScalarFunctionWithTail(child);
ColumnsProcessor proc = null;
ColumnProcessor proc = null;
Expression resolvedGroupedExp = null;
int resolvedExpIndex = -1;
// look-up the hierarchy to match the group
for (int i = wrappingFunctions.size() - 1; i >= 0 && resolvedGroupedExp == null; i--) {
Expression exp = wrappingFunctions.get(i);
parentGroup = groupInfo != null ? groupInfo.parentGroupFor(exp) : null;
parentGroup = groupingContext != null ? groupingContext.parentGroupFor(exp) : null;
// found group for expression or bumped into an aggregate (can happen when dealing with a root group)
if (parentGroup != null || Functions.isAggregateFunction(exp)) {
@ -284,10 +290,10 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
}
// initialize parent if needed
parentGroup = parentGroup == null && groupInfo != null ? groupInfo.parentGroupFor(resolvedGroupedExp) : parentGroup;
parentGroup = parentGroup == null && groupingContext != null ? groupingContext.parentGroupFor(resolvedGroupedExp) : parentGroup;
if (resolvedGroupedExp instanceof Attribute) {
queryC = useNamedReference(((Attribute) resolvedGroupedExp), proc, aggMap, queryC);
queryC = useNamedReference(((Attribute) resolvedGroupedExp), proc, groupMap, queryC);
}
// a scalar function can be used only if has been already used for grouping
@ -308,12 +314,12 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
throw new SqlIllegalArgumentException("Expected aggregate function inside alias; got %s", child.nodeString());
}
AggregateFunction f = (AggregateFunction) resolvedGroupedExp;
queryC = addFunction(parentGroup, f, proc, queryC);
queryC = addFunction(parentGroup, f, proc, compoundAggMap, queryC);
}
}
// not an Alias, means it's an Attribute
else {
queryC = useNamedReference(ne, null, aggMap, queryC);
queryC = useNamedReference(ne, null, groupMap, queryC);
}
}
@ -327,7 +333,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
// the agg is an actual value (field) that points to a group
// so look it up and create an extractor for it
private QueryContainer useNamedReference(NamedExpression ne, ColumnsProcessor proc, Map<ExpressionId, GroupingAgg> groupMap, QueryContainer queryC) {
private QueryContainer useNamedReference(NamedExpression ne, ColumnProcessor proc, Map<ExpressionId, GroupingAgg> groupMap, QueryContainer queryC) {
GroupingAgg aggInfo = groupMap.get(ne.id());
if (aggInfo == null) {
throw new SqlIllegalArgumentException("Cannot find group '%s'", ne.name());
@ -335,8 +341,8 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
return queryC.addAggRef(aggInfo.propertyPath(), proc);
}
private QueryContainer addFunction(GroupingAgg parentAgg, Function f, ColumnsProcessor proc, QueryContainer queryC) {
String functionId = f.id().toString();
private QueryContainer addFunction(GroupingAgg parentAgg, AggregateFunction f, ColumnProcessor proc, Map<CompoundAggregate, String> compoundAggMap, QueryContainer queryC) {
String functionId = f.functionId();
// handle count as a special case agg
if (f instanceof Count) {
Count c = (Count) f;
@ -344,9 +350,34 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
return queryC.addAggCount(parentAgg, functionId, proc);
}
}
// otherwise translate it to an agg
String parentPath = parentAgg != null ? parentAgg.asParentPath() : null;
String groupId = parentAgg != null ? parentAgg.id() : null;
if (f instanceof InnerAggregate) {
InnerAggregate ia = (InnerAggregate) f;
CompoundAggregate outer = ia.outer();
String cAggPath = compoundAggMap.get(outer);
// the compound agg hasn't been seen before so initialize it
if (cAggPath == null) {
LeafAgg leafAgg = toAgg(parentPath, functionId, outer);
cAggPath = leafAgg.propertyPath();
compoundAggMap.put(outer, cAggPath);
// add the agg without the default ref to it
queryC = queryC.with(queryC.aggs().addAgg(leafAgg));
}
String aggPath = AggPath.metricValue(cAggPath, ia.innerId());
// FIXME: concern leak - hack around MatrixAgg which is not generalized (afaik)
if (ia.innerKey() != null) {
proc = QueryTranslator.matrixFieldExtractor(ia.innerKey()).andThen(proc);
}
return queryC.addAggRef(aggPath, proc);
}
return queryC.addAgg(groupId, toAgg(parentPath, functionId, f), proc);
}
}
@ -470,19 +501,16 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
return exec.with(qContainer);
}
}
/**
* Rule for folding physical plans together.
*/
abstract static class FoldingRule<SubPlan extends PhysicalPlan> extends Rule<SubPlan, PhysicalPlan> {
@Override
public final PhysicalPlan apply(PhysicalPlan plan) {
return plan.transformUp(this::rule, typeToken());
}
@Override
protected abstract PhysicalPlan rule(SubPlan plan);
}
}
// rule for folding physical plans together
abstract class FoldingRule<SubPlan extends PhysicalPlan> extends Rule<SubPlan, PhysicalPlan> {
@Override
public final PhysicalPlan apply(PhysicalPlan plan) {
return plan.transformUp(this::rule, typeToken());
}
@Override
protected abstract PhysicalPlan rule(SubPlan plan);
}

View File

@ -9,7 +9,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import org.elasticsearch.xpack.sql.session.RowSetCursor;
public abstract class CliUtils { // TODO made public so it could be shared with tests
abstract class CliUtils {
// this toString is a bit convoluted since it tries to be smart and pad the columns according to their values
// as such it will look inside the row, find the max for each column and pad all the values accordingly
@ -17,7 +17,7 @@ public abstract class CliUtils { // TODO made public so it could be shared with
// a row needs to be iterated upon to fill up the values that don't take extra lines
// Warning: this method _consumes_ a rowset
public static String toString(RowSetCursor cursor) {
static String toString(RowSetCursor cursor) {
if (cursor.rowSize() == 1 && cursor.size() == 1 && cursor.column(0).toString().startsWith("digraph ")) {
return cursor.column(0).toString();
}

View File

@ -14,14 +14,6 @@ import org.elasticsearch.xpack.sql.util.StringUtils;
import static java.lang.String.format;
public abstract class Agg {
public static final char PATH_DELIMITER_CHAR = '>';
public static final String PATH_DELIMITER = String.valueOf(PATH_DELIMITER_CHAR);
public static final String PATH_BUCKET_VALUE = "._key";
public static final String PATH_BUCKET_COUNT = "._count";
public static final String PATH_BUCKET_VALUE_FORMATTED = "._key_as_string";
public static final String PATH_VALUE = ".value";
private final String id;
private final String fieldName;
private final String propertyPath;

View File

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import static org.elasticsearch.search.aggregations.AggregationBuilders.extendedStats;
public class ExtendedStatsAgg extends LeafAgg {
public ExtendedStatsAgg(String id, String propertyPath, String fieldName) {
super(id, propertyPath, fieldName);
}
@Override
AggregationBuilder toBuilder() {
return extendedStats(id()).field(fieldName());
}
}

View File

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
import java.util.List;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import static org.elasticsearch.search.aggregations.MatrixStatsAggregationBuilders.matrixStats;
public class MatrixStatsAgg extends LeafAgg {
private final List<String> fields;
public MatrixStatsAgg(String id, String propertyPath, List<String> fields) {
super(id, propertyPath, "<multi-field>");
this.fields = fields;
}
@Override
AggregationBuilder toBuilder() {
return matrixStats(id()).fields(fields);
}
}

View File

@ -0,0 +1,10 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
public abstract class MultiFieldAgg {
}

View File

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.sql.querydsl.agg;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import static org.elasticsearch.search.aggregations.AggregationBuilders.stats;
public class StatsAgg extends LeafAgg {
public StatsAgg(String id, String propertyPath, String fieldName) {
super(id, propertyPath, fieldName);
}
@Override
AggregationBuilder toBuilder() {
return stats(id()).field(fieldName());
}
}

View File

@ -5,24 +5,15 @@
*/
package org.elasticsearch.xpack.sql.querydsl.container;
import org.elasticsearch.xpack.sql.querydsl.agg.Agg;
import org.elasticsearch.xpack.sql.querydsl.agg.AggPath;
public class AggRef implements Reference {
private final String path;
// agg1 = 0
// agg1>agg2._count = 1
// agg1>agg2>agg3.value = 1 (agg3.value has the same depth as agg2._count)
// agg1>agg2>agg3._count = 2
private final int depth;
AggRef(String path) {
this.path = path;
int dpt = countCharIn(path, Agg.PATH_DELIMITER_CHAR);
if (path.endsWith(Agg.PATH_VALUE)) {
dpt = Math.max(0, dpt - 1);
}
depth = dpt;
depth = AggPath.depth(path);
}
@Override
@ -38,14 +29,4 @@ public class AggRef implements Reference {
public String path() {
return path;
}
private static int countCharIn(CharSequence sequence, char c) {
int count = 0;
for (int i = 0; i < sequence.length(); i++) {
if (c == sequence.charAt(i)) {
count++;
}
}
return count;
}
}
}

View File

@ -5,19 +5,19 @@
*/
package org.elasticsearch.xpack.sql.querydsl.container;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
public class ProcessingRef implements Reference {
private final ColumnsProcessor processor;
private final ColumnProcessor processor;
private final Reference ref;
public ProcessingRef(ColumnsProcessor processor, Reference ref) {
public ProcessingRef(ColumnProcessor processor, Reference ref) {
this.processor = processor;
this.ref = ref;
}
public ColumnsProcessor processor() {
public ColumnProcessor processor() {
return processor;
}

View File

@ -7,7 +7,6 @@ package org.elasticsearch.xpack.sql.querydsl.container;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@ -20,8 +19,8 @@ import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.FieldAttribute;
import org.elasticsearch.xpack.sql.expression.NestedFieldAttribute;
import org.elasticsearch.xpack.sql.expression.RootFieldAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnsProcessor;
import org.elasticsearch.xpack.sql.querydsl.agg.Agg;
import org.elasticsearch.xpack.sql.expression.function.scalar.ColumnProcessor;
import org.elasticsearch.xpack.sql.querydsl.agg.AggPath;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.agg.GroupingAgg;
import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg;
@ -29,6 +28,7 @@ import org.elasticsearch.xpack.sql.querydsl.query.AndQuery;
import org.elasticsearch.xpack.sql.querydsl.query.MatchAll;
import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery;
import org.elasticsearch.xpack.sql.querydsl.query.Query;
import org.elasticsearch.xpack.sql.util.CollectionUtils;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
@ -45,7 +45,7 @@ public class QueryContainer {
// aliases (maps an alias to its actual resolved attribute)
private final Map<Attribute, Attribute> aliases;
// processors for a given attribute - wraps the processor over the resolved ref
private final Map<Attribute, ColumnsProcessor> processors;
private final Map<Attribute, ColumnProcessor> processors;
// pseudo functions (like count) - that are 'extracted' from other aggs
private final Map<String, GroupingAgg> pseudoFunctions;
@ -60,7 +60,7 @@ public class QueryContainer {
this(null, null, null, null, null, null, null, -1);
}
public QueryContainer(Query query, Aggs aggs, List<Reference> refs, Map<Attribute, Attribute> aliases, Map<Attribute, ColumnsProcessor> processors, Map<String, GroupingAgg> pseudoFunctions, Set<Sort> sort, int limit) {
public QueryContainer(Query query, Aggs aggs, List<Reference> refs, Map<Attribute, Attribute> aliases, Map<Attribute, ColumnProcessor> processors, Map<String, GroupingAgg> pseudoFunctions, Set<Sort> sort, int limit) {
this.query = query;
this.aggs = aggs == null ? new Aggs() : aggs;
this.aliases = aliases == null || aliases.isEmpty() ? emptyMap() : aliases;
@ -105,7 +105,7 @@ public class QueryContainer {
return aliases;
}
public Map<Attribute, ColumnsProcessor> processors() {
public Map<Attribute, ColumnProcessor> processors() {
return processors;
}
@ -149,7 +149,7 @@ public class QueryContainer {
return new QueryContainer(query, aggs, refs, a, processors, pseudoFunctions, sort, limit);
}
public QueryContainer withProcessors(Map<Attribute, ColumnsProcessor> p) {
public QueryContainer withProcessors(Map<Attribute, ColumnProcessor> p) {
return new QueryContainer(query, aggs, refs, aliases, p, pseudoFunctions, sort, limit);
}
@ -177,7 +177,7 @@ public class QueryContainer {
}
private Reference wrapProcessorIfNeeded(Attribute attr, Reference ref) {
ColumnsProcessor columnProcessor = processors.get(attr);
ColumnProcessor columnProcessor = processors.get(attr);
return columnProcessor != null ? new ProcessingRef(columnProcessor, ref) : ref;
}
@ -244,24 +244,28 @@ public class QueryContainer {
return addAggRef(aggPath, null);
}
public QueryContainer addAggRef(String aggPath, ColumnsProcessor processor) {
Reference ref = new AggRef(aggPath);
ref = processor != null ? new ProcessingRef(processor, ref) : ref;
public QueryContainer addAggRef(String aggPath, ColumnProcessor processor) {
return addAggRef(new AggRef(aggPath), processor);
}
public QueryContainer addAggRef(AggRef customRef, ColumnProcessor processor) {
Reference ref = processor != null ? new ProcessingRef(processor, customRef) : customRef;
return addRef(ref);
}
public QueryContainer addAggCount(GroupingAgg parentGroup, String functionId, ColumnsProcessor processor) {
Reference ref = parentGroup == null ? TotalCountRef.INSTANCE : new AggRef(parentGroup.asParentPath() + Agg.PATH_BUCKET_COUNT);
public QueryContainer addAggCount(GroupingAgg parentGroup, String functionId, ColumnProcessor processor) {
Reference ref = parentGroup == null ? TotalCountRef.INSTANCE : new AggRef(AggPath.bucketCount(parentGroup.asParentPath()));
ref = processor != null ? new ProcessingRef(processor, ref) : ref;
Map<String, GroupingAgg> newFunc = new LinkedHashMap<>(pseudoFunctions);
newFunc.put(functionId, parentGroup);
return new QueryContainer(query, aggs, combine(refs, ref), aliases, processors, newFunc, sort, limit);
return new QueryContainer(query, aggs, combine(refs, ref), aliases, processors, combine(pseudoFunctions, CollectionUtils.of(functionId, parentGroup)), sort, limit);
}
public QueryContainer addAgg(String groupId, LeafAgg agg, ColumnsProcessor processor) {
Reference ref = new AggRef(agg.propertyPath());
ref = processor != null ? new ProcessingRef(processor, ref) : ref;
public QueryContainer addAgg(String groupId, LeafAgg agg, ColumnProcessor processor) {
return addAgg(groupId, agg, agg.propertyPath(), processor);
}
public QueryContainer addAgg(String groupId, LeafAgg agg, String aggRefPath, ColumnProcessor processor) {
AggRef aggRef = new AggRef(aggRefPath);
Reference ref = processor != null ? new ProcessingRef(processor, aggRef) : aggRef;
return new QueryContainer(query, aggs.addAgg(groupId, agg), combine(refs, ref), aliases, processors, pseudoFunctions, sort, limit);
}

View File

@ -77,21 +77,20 @@ public abstract class Node<T extends Node<T>> {
});
}
@SuppressWarnings("unchecked")
public <E> void forEachPropertiesOnly(Consumer<? super E> rule, Class<E> typeToken) {
forEachProperty((T) this, rule, typeToken);
forEachProperty(rule, typeToken);
}
public <E> void forEachPropertiesDown(Consumer<? super E> rule, Class<E> typeToken) {
forEachDown(e -> forEachProperty(e, rule, typeToken));
forEachDown(e -> e.forEachProperty(rule, typeToken));
}
public <E> void forEachPropertiesUp(Consumer<? super E> rule, Class<E> typeToken) {
forEachUp(e -> forEachProperty(e, rule, typeToken));
forEachUp(e -> e.forEachProperty(rule, typeToken));
}
@SuppressWarnings("unchecked")
private <E> void forEachProperty(T node, Consumer<? super E> rule, Class<E> typeToken) {
protected <E> void forEachProperty(Consumer<? super E> rule, Class<E> typeToken) {
for (Object prop : NodeUtils.properties(this)) {
// skip children (only properties are interesting)
if (prop != children && !children.contains(prop) && typeToken.isInstance(prop)) {

View File

@ -5,6 +5,10 @@
*/
package org.elasticsearch.xpack.sql.tree;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.util.Assert;
import org.elasticsearch.xpack.sql.util.ObjectUtils;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
@ -16,11 +20,6 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
import org.elasticsearch.xpack.sql.util.Assert;
import org.elasticsearch.xpack.sql.util.ObjectUtils;
import java.util.Objects;
import java.util.Set;
@ -42,6 +41,7 @@ public abstract class NodeUtils {
private static final String TO_STRING_IGNORE_PROP = "location";
private static final int TO_STRING_MAX_PROP = 10;
private static final int TO_STRING_MAX_WIDTH = 100;
private static final Map<Class<?>, NodeInfo> CACHE = new LinkedHashMap<>();
@ -146,9 +146,7 @@ public abstract class NodeUtils {
Parameter[] parameters = ctr.getParameters();
for (int paramIndex = 0; paramIndex < parameters.length; paramIndex++) {
Parameter param = parameters[paramIndex];
// NOCOMMIT - oh boy. this is worth digging into. I suppose we preserve these for now but I don't think this is safe to rely on.
Assert.isTrue(param.isNamePresent(), "Can't find constructor parameter names for [%s]. Is class debug information available?",
clazz.toGenericString());
Assert.isTrue(param.isNamePresent(), "Can't find constructor parameter names for [%s]. Is class debug information available?", clazz.toGenericString());
String paramName = param.getName();
if (paramName.equals("children")) {
@ -225,6 +223,7 @@ public abstract class NodeUtils {
List<?> children = tree.children();
// eliminate children (they are rendered as part of the tree)
int maxProperties = TO_STRING_MAX_PROP;
int maxWidth = 0;
Iterator<String> nameIterator = keySet.iterator();
boolean needsComma = false;
@ -242,7 +241,12 @@ public abstract class NodeUtils {
if (needsComma) {
sb.append(",");
}
sb.append(Objects.toString(object));
String stringValue = Objects.toString(object);
if (maxWidth + stringValue.length() > TO_STRING_MAX_WIDTH) {
stringValue = stringValue.substring(0, Math.max(0, TO_STRING_MAX_WIDTH - maxWidth)) + "~";
}
maxWidth += stringValue.length();
sb.append(stringValue);
needsComma = true;
}

View File

@ -22,12 +22,16 @@ public abstract class Types {
@SuppressWarnings("unchecked")
public static Map<String, DataType> fromEs(Map<String, Object> asMap) {
return startWalking((Map<String, Object>) asMap.get("properties"));
Map<String, Object> props = (Map<String, Object>) asMap.get("properties");
return props == null || props.isEmpty() ? emptyMap() : startWalking(props);
}
private static Map<String, DataType> startWalking(Map<String, Object> mapping) {
Map<String, DataType> translated = new LinkedHashMap<>();
if (mapping == null) {
return emptyMap();
}
for (Entry<String, Object> entry : mapping.entrySet()) {
walkMapping(entry.getKey(), entry.getValue(), translated);
}

View File

@ -5,14 +5,14 @@
*/
package org.elasticsearch.xpack.sql.util;
import org.elasticsearch.common.Strings;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.elasticsearch.common.Strings;
import static java.util.stream.Collectors.joining;
public abstract class StringUtils {
@ -59,6 +59,7 @@ public abstract class StringUtils {
return strings.stream().collect(joining("."));
}
//CamelCase to camel_case
public static String camelCaseToUnderscore(String string) {
if (!Strings.hasText(string)) {
return EMPTY;