SQL: Enable accurate hit tracking on demand (#39527)

Queries that require counting of all hits (COUNT(*) on implicit
group by), now enable accurate hit tracking.

Fix #37971

(cherry picked from commit 265b637cf6df08986a890b8b5daf012c2b0c1699)
This commit is contained in:
Costin Leau 2019-03-01 23:06:16 +02:00 committed by Costin Leau
parent f1a7166708
commit dfe81b260e
6 changed files with 107 additions and 37 deletions

View File

@ -153,6 +153,7 @@ public class CliExplainIT extends CliIntegrationTestCase {
assertThat(readLine(), startsWith(" }"));
assertThat(readLine(), startsWith(" }"));
assertThat(readLine(), startsWith(" ]"));
assertThat(readLine(), startsWith(" \"track_total_hits\" : 2147483647"));
assertThat(readLine(), startsWith("}]"));
assertEquals("", readLine());
}

View File

@ -169,6 +169,9 @@ public abstract class SourceGenerator {
// disable source fetching (only doc values are used)
disableSource(builder);
}
if (query.shouldTrackHits()) {
builder.trackTotalHits(true);
}
}
private static void disableSource(SearchSourceBuilder builder) {

View File

@ -14,7 +14,6 @@ import org.elasticsearch.xpack.sql.expression.AttributeMap;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.Literal;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.Order;
import org.elasticsearch.xpack.sql.expression.function.Function;
@ -152,7 +151,8 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
queryC.pseudoFunctions(),
new AttributeMap<>(processors),
queryC.sort(),
queryC.limit());
queryC.limit(),
queryC.shouldTrackHits());
return new EsQueryExec(exec.source(), exec.index(), project.output(), clone);
}
return project;
@ -180,7 +180,8 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
qContainer.pseudoFunctions(),
qContainer.scalarFunctions(),
qContainer.sort(),
qContainer.limit());
qContainer.limit(),
qContainer.shouldTrackHits());
return exec.with(qContainer);
}
@ -391,10 +392,16 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
if (f instanceof Count) {
Count c = (Count) f;
// COUNT(*) or COUNT(<literal>)
if (c.field() instanceof Literal) {
AggRef ref = groupingAgg == null ?
GlobalCountRef.INSTANCE :
new GroupByRef(groupingAgg.id(), Property.COUNT, null);
if (c.field().foldable()) {
AggRef ref = null;
if (groupingAgg == null) {
ref = GlobalCountRef.INSTANCE;
// if the count points to the total track hits, enable accurate count retrieval
queryC = queryC.withTrackHits();
} else {
ref = new GroupByRef(groupingAgg.id(), Property.COUNT, null);
}
Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(queryC.pseudoFunctions());
pseudoFunctions.put(functionId, groupingAgg);
@ -406,7 +413,7 @@ class QueryFolder extends RuleExecutor<PhysicalPlan> {
queryC = queryC.with(queryC.aggs().addAgg(leafAgg));
return new Tuple<>(queryC, a);
}
// the only variant left - COUNT(DISTINCT) - will be covered by the else branch below
// the only variant left - COUNT(DISTINCT) - will be covered by the else branch below as it maps to an aggregation
}
AggPathInput aggInput = null;

View File

@ -26,7 +26,6 @@ import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.querydsl.agg.Aggs;
import org.elasticsearch.xpack.sql.querydsl.agg.GroupByKey;
import org.elasticsearch.xpack.sql.querydsl.agg.LeafAgg;
import org.elasticsearch.xpack.sql.querydsl.container.GroupByRef.Property;
import org.elasticsearch.xpack.sql.querydsl.query.BoolQuery;
import org.elasticsearch.xpack.sql.querydsl.query.MatchAll;
import org.elasticsearch.xpack.sql.querydsl.query.NestedQuery;
@ -81,23 +80,26 @@ public class QueryContainer {
private final Set<Sort> sort;
private final int limit;
private final boolean trackHits;
// computed
private Boolean aggsOnly;
private Boolean customSort;
public QueryContainer() {
this(null, null, null, null, null, null, null, -1);
this(null, null, null, null, null, null, null, -1, false);
}
public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction, ExpressionId>> fields,
public QueryContainer(Query query,
Aggs aggs,
List<Tuple<FieldExtraction,
ExpressionId>> fields,
AttributeMap<Attribute> aliases,
Map<String, GroupByKey> pseudoFunctions,
AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort,
int limit) {
Map<String, GroupByKey> pseudoFunctions,
AttributeMap<Pipe> scalarFunctions,
Set<Sort> sort,
int limit,
boolean trackHits) {
this.query = query;
this.aggs = aggs == null ? Aggs.EMPTY : aggs;
this.fields = fields == null || fields.isEmpty() ? emptyList() : fields;
@ -106,6 +108,7 @@ public class QueryContainer {
this.scalarFunctions = scalarFunctions == null || scalarFunctions.isEmpty() ? AttributeMap.emptyAttributeMap() : scalarFunctions;
this.sort = sort == null || sort.isEmpty() ? emptySet() : sort;
this.limit = limit;
this.trackHits = trackHits;
}
/**
@ -230,38 +233,46 @@ public class QueryContainer {
return fields.size() > 0;
}
public boolean shouldTrackHits() {
return trackHits;
}
//
// copy methods
//
public QueryContainer with(Query q) {
return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}
public QueryContainer withAliases(AttributeMap<Attribute> a) {
return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, a, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}
public QueryContainer withPseudoFunctions(Map<String, GroupByKey> p) {
return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, p, scalarFunctions, sort, limit, trackHits);
}
public QueryContainer with(Aggs a) {
return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, a, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}
public QueryContainer withLimit(int l) {
return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l);
return l == limit ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, l, trackHits);
}
public QueryContainer withTrackHits() {
return trackHits ? this : new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, true);
}
public QueryContainer withScalarProcessors(AttributeMap<Pipe> procs) {
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, procs, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, procs, sort, limit, trackHits);
}
public QueryContainer addSort(Sort sortable) {
Set<Sort> sort = new LinkedHashSet<>(this.sort);
sort.add(sortable);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit);
return new QueryContainer(query, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits);
}
private String aliasName(Attribute attr) {
@ -287,7 +298,7 @@ public class QueryContainer {
attr.field().isAggregatable(), attr.parent().name());
nestedRefs.add(nestedFieldRef);
return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit),
return new Tuple<>(new QueryContainer(q, aggs, fields, aliases, pseudoFunctions, scalarFunctions, sort, limit, trackHits),
nestedFieldRef);
}
@ -390,7 +401,7 @@ public class QueryContainer {
ExpressionId id = attr instanceof AggregateFunctionAttribute ? ((AggregateFunctionAttribute) attr).innerId() : attr.id();
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, id)), aliases, pseudoFunctions,
scalarFunctions,
sort, limit);
sort, limit, trackHits);
}
public AttributeMap<Pipe> scalarFunctions() {
@ -401,16 +412,6 @@ public class QueryContainer {
// agg methods
//
public QueryContainer addAggCount(GroupByKey group, ExpressionId functionId) {
FieldExtraction ref = group == null ? GlobalCountRef.INSTANCE : new GroupByRef(group.id(), Property.COUNT, null);
Map<String, GroupByKey> pseudoFunctions = new LinkedHashMap<>(this.pseudoFunctions);
pseudoFunctions.put(functionId.toString(), group);
return new QueryContainer(query, aggs, combine(fields, new Tuple<>(ref, functionId)),
aliases,
pseudoFunctions,
scalarFunctions, sort, limit);
}
public QueryContainer addAgg(String groupId, LeafAgg agg) {
return with(aggs.addAgg(agg));
}
@ -465,4 +466,4 @@ public class QueryContainer {
throw new RuntimeException("error rendering", e);
}
}
}
}

View File

@ -10,6 +10,7 @@ import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.test.ESTestCase;
@ -111,6 +112,13 @@ public class SourceGeneratorTests extends ESTestCase {
assertEquals(singletonList(fieldSort("_doc").order(SortOrder.ASC)), sourceBuilder.sorts());
}
public void testTrackHits() {
SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(new QueryContainer().withTrackHits(), null,
randomIntBetween(1, 10));
assertEquals("Should have tracked hits", Integer.valueOf(SearchContext.TRACK_TOTAL_HITS_ACCURATE),
sourceBuilder.trackTotalHitsUpTo());
}
public void testNoSortIfAgg() {
QueryContainer container = new QueryContainer()
.addGroups(singletonList(new GroupByValue("group_id", "group_column")))

View File

@ -678,4 +678,54 @@ public class QueryTranslatorTests extends ESTestCase {
"{\"date\":{\"order\":\"desc\",\"missing\":\"_last\",\"unmapped_type\":\"date\"}}]}}}}}"));
}
}
public void testGlobalCountInImplicitGroupByForcesTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(*) FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertTrue("Should be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testGlobalCountAllInImplicitGroupByForcesTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(ALL *) FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertTrue("Should be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testGlobalCountInSpecificGroupByDoesNotForceTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(*) FROM test GROUP BY int");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertFalse("Should NOT be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testFieldAllCountDoesNotTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(ALL int) FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertFalse("Should NOT be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testFieldCountDoesNotTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(int) FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertFalse("Should NOT be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testDistinctCountDoesNotTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT COUNT(DISTINCT int) FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertFalse("Should NOT be tracking hits", eqe.queryContainer().shouldTrackHits());
}
public void testNoCountDoesNotTrackHits() throws Exception {
PhysicalPlan p = optimizeAndPlan("SELECT int FROM test");
assertEquals(EsQueryExec.class, p.getClass());
EsQueryExec eqe = (EsQueryExec) p;
assertFalse("Should NOT be tracking hits", eqe.queryContainer().shouldTrackHits());
}
}