Add query context parameter to control limiting select rows (#14476)

* Add query context parameter to control limiting select rows

* Add unit tests

* Address review comments

* Address review comments

* Address review comments
This commit is contained in:
Adarsh Sanjeev 2023-06-28 17:54:24 +05:30 committed by GitHub
parent cb3a9d2b57
commit 233233c92d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 203 additions and 23 deletions

View File

@ -90,6 +90,7 @@ import org.apache.druid.msq.indexing.DataSourceMSQDestination;
import org.apache.druid.msq.indexing.InputChannelFactory;
import org.apache.druid.msq.indexing.InputChannelsImpl;
import org.apache.druid.msq.indexing.MSQControllerTask;
import org.apache.druid.msq.indexing.MSQSelectDestination;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher;
@ -467,7 +468,8 @@ public class ControllerImpl implements Controller
queryDef,
resultsYielder,
task.getQuerySpec().getColumnMappings(),
task.getSqlTypeNames()
task.getSqlTypeNames(),
MultiStageQueryContext.getSelectDestination(task.getQuerySpec().getQuery().context())
);
} else {
resultsReport = null;
@ -2032,7 +2034,8 @@ public class ControllerImpl implements Controller
final QueryDefinition queryDef,
final Yielder<Object[]> resultsYielder,
final ColumnMappings columnMappings,
@Nullable final List<SqlTypeName> sqlTypeNames
@Nullable final List<SqlTypeName> sqlTypeNames,
final MSQSelectDestination selectDestination
)
{
final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature();
@ -2047,7 +2050,7 @@ public class ControllerImpl implements Controller
);
}
return new MSQResultsReport(mappedSignature.build(), sqlTypeNames, resultsYielder, null);
return MSQResultsReport.createReportAndLimitRowsIfNeeded(mappedSignature.build(), sqlTypeNames, resultsYielder, selectDestination);
}
private static MSQStatusReport makeStatusReport(

View File

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing;
/**
 * Enumerates the places the results of a select query can be sent to.
 */
public enum MSQSelectDestination
{
  /**
   * Every result row is written straight into the task report.
   */
  TASK_REPORT(false),

  /**
   * Results are written to durable storage as frame files. The task report may then carry only a truncated preview.
   */
  DURABLE_STORAGE(true);

  /**
   * Whether the rows placed in the task report may be cut down to a preview for this destination.
   */
  private final boolean truncateResultsInTaskReport;

  MSQSelectDestination(final boolean truncateInReport)
  {
    this.truncateResultsInTaskReport = truncateInReport;
  }

  /**
   * @return {@code true} for destinations whose task report may hold a truncated preview of the results,
   * {@code false} for destinations whose task report holds all of the results
   */
  public boolean shouldTruncateResultsInTaskReport()
  {
    return truncateResultsInTaskReport;
  }
}

View File

@ -24,10 +24,12 @@ import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.common.config.Configs;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.indexing.MSQSelectDestination;
import org.apache.druid.segment.column.ColumnType;
import javax.annotation.Nullable;
@ -44,30 +46,20 @@ public class MSQResultsReport
private final List<ColumnAndType> signature;
@Nullable
private final List<SqlTypeName> sqlTypeNames;
private final List<Object[]> results;
private final Yielder<Object[]> resultYielder;
private final boolean resultsTruncated;
public MSQResultsReport(
final List<ColumnAndType> signature,
@Nullable final List<SqlTypeName> sqlTypeNames,
Yielder<Object[]> resultYielder,
final Yielder<Object[]> resultYielder,
@Nullable Boolean resultsTruncated
)
{
this.signature = Preconditions.checkNotNull(signature, "signature");
this.sqlTypeNames = sqlTypeNames;
this.results = new ArrayList<>();
int rowCount = 0;
while (!resultYielder.isDone() && rowCount < Limits.MAX_SELECT_RESULT_ROWS) {
results.add(resultYielder.get());
resultYielder = resultYielder.next(null);
++rowCount;
}
if (resultsTruncated != null) {
this.resultsTruncated = !resultYielder.isDone() || resultsTruncated;
} else {
this.resultsTruncated = !resultYielder.isDone();
}
this.resultYielder = Preconditions.checkNotNull(resultYielder, "resultYielder");
this.resultsTruncated = Configs.valueOrDefault(resultsTruncated, false);
}
/**
@ -84,6 +76,27 @@ public class MSQResultsReport
return new MSQResultsReport(signature, sqlTypeNames, Yielders.each(Sequences.simple(results)), resultsTruncated);
}
/**
 * Builds a results report for a select query, capping the number of rows when the destination asks for it.
 *
 * If {@code selectDestination.shouldTruncateResultsInTaskReport()} is false, the supplied yielder is handed to the
 * report as-is and the report is marked as not truncated. Otherwise at most {@link Limits#MAX_SELECT_RESULT_ROWS}
 * rows are pulled from the yielder into an in-memory list, and the report is marked as truncated exactly when the
 * yielder still had rows left after that.
 *
 * @param signature         column names and types of the result rows
 * @param sqlTypeNames      SQL type names of the result columns, may be null
 * @param resultYielder     yielder positioned at the first result row
 * @param selectDestination destination that decides whether the rows are capped
 * @return the report wrapping either the original yielder or a yielder over the capped rows
 */
public static MSQResultsReport createReportAndLimitRowsIfNeeded(
    final List<ColumnAndType> signature,
    @Nullable final List<SqlTypeName> sqlTypeNames,
    Yielder<Object[]> resultYielder,
    MSQSelectDestination selectDestination
)
{
  if (!selectDestination.shouldTruncateResultsInTaskReport()) {
    // No cap requested: the report streams straight from the caller's yielder.
    return new MSQResultsReport(signature, sqlTypeNames, resultYielder, false);
  }

  final List<Object[]> previewRows = new ArrayList<>();
  for (int taken = 0; !resultYielder.isDone() && taken < Limits.MAX_SELECT_RESULT_ROWS; taken++) {
    previewRows.add(resultYielder.get());
    resultYielder = resultYielder.next(null);
  }

  final Yielder<Object[]> previewYielder = Yielders.each(Sequences.simple(previewRows));
  // Rows remaining in the source after the cap was reached means the preview is incomplete.
  final boolean truncated = !resultYielder.isDone();
  return new MSQResultsReport(signature, sqlTypeNames, previewYielder, truncated);
}
@JsonProperty("signature")
public List<ColumnAndType> getSignature()
{
@ -99,9 +112,9 @@ public class MSQResultsReport
}
@JsonProperty("results")
public List<Object[]> getResults()
public Yielder<Object[]> getResultYielder()
{
return results;
return resultYielder;
}
@JsonProperty("resultsTruncted")

View File

@ -27,6 +27,7 @@ import com.opencsv.RFC4180Parser;
import com.opencsv.RFC4180ParserBuilder;
import org.apache.druid.msq.exec.ClusterStatisticsMergeMode;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.indexing.MSQSelectDestination;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext;
@ -64,6 +65,10 @@ import java.util.stream.Collectors;
* Can be <b>PARALLEL</b>, <b>SEQUENTIAL</b> or <b>AUTO</b>. See {@link ClusterStatisticsMergeMode} for more information on each mode.
* Default value is <b>SEQUENTIAL</b></li>
*
* <li><b>selectDestination</b>: If the query is a Select, determines the location to write results to, once the query
* is finished. Depending on the location, the results might also be truncated to {@link Limits#MAX_SELECT_RESULT_ROWS}.
* Default value is {@link MSQSelectDestination#TASK_REPORT}, which writes all the results to the report.</li>
*
* <li><b>useAutoColumnSchemas</b>: Temporary flag to allow experimentation using
* {@link org.apache.druid.segment.AutoTypeColumnSchema} for all 'standard' type columns during segment generation,
* see {@link DimensionSchemaUtils#createDimensionSchema} for more details.
@ -87,6 +92,8 @@ public class MultiStageQueryContext
public static final String CTX_DURABLE_SHUFFLE_STORAGE = "durableShuffleStorage";
private static final boolean DEFAULT_DURABLE_SHUFFLE_STORAGE = false;
public static final String CTX_SELECT_DESTINATION = "selectDestination";
private static final String DEFAULT_SELECT_DESTINATION = MSQSelectDestination.TASK_REPORT.toString();
public static final String CTX_FAULT_TOLERANCE = "faultTolerance";
public static final boolean DEFAULT_FAULT_TOLERANCE = false;
@ -204,6 +211,16 @@ public class MultiStageQueryContext
);
}
/**
 * Resolves the destination for select query results from the query context.
 *
 * Reads the {@link #CTX_SELECT_DESTINATION} entry and falls back to the destination named by
 * {@code DEFAULT_SELECT_DESTINATION} when the entry is absent.
 *
 * Parsing is delegated to {@link QueryContext#getEnum} instead of calling {@code Enum.valueOf} directly, so that a
 * bad value is reported by the query-context layer rather than surfacing as a bare "No enum constant"
 * {@link IllegalArgumentException} that does not mention the offending key.
 * NOTE(review): {@code getEnum} is expected to upper-case string values before matching and to throw
 * {@code BadQueryContextException} on an unknown value; confirm against the Druid version in use.
 *
 * @param queryContext context of the query being run
 * @return the requested destination, or the default when none was requested
 */
public static MSQSelectDestination getSelectDestination(final QueryContext queryContext)
{
  return queryContext.getEnum(
      CTX_SELECT_DESTINATION,
      MSQSelectDestination.class,
      MSQSelectDestination.valueOf(DEFAULT_SELECT_DESTINATION)
  );
}
public static int getRowsInMemory(final QueryContext queryContext)
{
return queryContext.getInt(CTX_ROWS_IN_MEMORY, DEFAULT_ROWS_IN_MEMORY);

View File

@ -35,12 +35,14 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.msq.indexing.MSQSelectDestination;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.test.CounterSnapshotMatcher;
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.QueryDataSource;
@ -74,6 +76,7 @@ import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.util.CalciteTests;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import org.junit.runner.RunWith;
@ -88,6 +91,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -1852,6 +1856,71 @@ public class MSQSelectTest extends MSQTestBase
result.add(new Object[]{1});
}
Map<String, Object> queryContext = new HashMap<>(context);
queryContext.put(MultiStageQueryContext.CTX_SELECT_DESTINATION, MSQSelectDestination.DURABLE_STORAGE.toString());
testSelectQuery()
.setSql(StringUtils.format(
" SELECT 1 as \"timestamp\"\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [%s],\"type\":\"local\"}',\n"
+ " '{\"type\": \"csv\", \"hasHeaderRow\": true}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ")",
externalFiles
))
.setExpectedRowSignature(dummyRowSignature)
.setExpectedMSQSpec(
MSQSpec
.builder()
.query(newScanQueryBuilder()
.dataSource(new ExternalDataSource(
new LocalInputSource(null, null, Collections.nCopies(numFiles, toRead)),
new CsvInputFormat(null, null, null, true, 0),
RowSignature.builder().add("timestamp", ColumnType.STRING).build()
))
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("v0")
.virtualColumns(new ExpressionVirtualColumn("v0", ExprEval.of(1L).toExpr(), ColumnType.LONG))
.context(defaultScanQueryContext(
queryContext,
RowSignature.builder().add("v0", ColumnType.LONG).build()
))
.build()
)
.columnMappings(new ColumnMappings(
ImmutableList.of(
new ColumnMapping("v0", "timestamp")
)
))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build())
.setQueryContext(queryContext)
.setExpectedResultRows(result)
.verifyResults();
}
@Test
public void testSelectRowsGetUntruncatedInReportsByDefault() throws IOException
{
RowSignature dummyRowSignature = RowSignature.builder().add("timestamp", ColumnType.LONG).build();
final int numFiles = 200;
final File toRead = MSQTestFileUtils.getResourceAsTemporaryFile(temporaryFolder, this, "/wikipedia-sampled.json");
final String toReadFileNameAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath());
String externalFiles = String.join(", ", Collections.nCopies(numFiles, toReadFileNameAsJson));
List<Object[]> result = new ArrayList<>();
for (int i = 0; i < 3800; ++i) {
result.add(new Object[]{1});
}
Assert.assertTrue(result.size() > Limits.MAX_SELECT_RESULT_ROWS);
testSelectQuery()
.setSql(StringUtils.format(
" SELECT 1 as \"timestamp\"\n"

View File

@ -31,6 +31,7 @@ import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter;
import org.apache.druid.indexing.common.TaskReport;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.guice.MSQIndexingModule;
@ -50,6 +51,7 @@ import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -124,8 +126,13 @@ public class MSQTaskReportTest
Assert.assertEquals(report.getPayload().getStatus().getPendingTasks(), report2.getPayload().getStatus().getPendingTasks());
Assert.assertEquals(report.getPayload().getStages(), report2.getPayload().getStages());
final List<Object[]> results2 = report2.getPayload().getResults().getResults();
Yielder<Object[]> yielder = report2.getPayload().getResults().getResultYielder();
final List<Object[]> results2 = new ArrayList<>();
while (!yielder.isDone()) {
results2.add(yielder.get());
yielder = yielder.next(null);
}
Assert.assertEquals(results.size(), results2.size());
for (int i = 0; i < results.size(); i++) {
Assert.assertArrayEquals(results.get(i), results2.get(i));

View File

@ -69,6 +69,7 @@ import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.metadata.input.InputSourceModule;
@ -768,7 +769,20 @@ public class MSQTestBase extends BaseCalciteQueryTest
if (resultsReport == null) {
return null;
} else {
return resultsReport.getResults();
Yielder<Object[]> yielder = resultsReport.getResultYielder();
List<Object[]> rows = new ArrayList<>();
while (!yielder.isDone()) {
rows.add(yielder.get());
yielder = yielder.next(null);
}
try {
yielder.close();
}
catch (IOException e) {
throw new ISE("Unable to get results from the report");
}
return rows;
}
}

View File

@ -22,6 +22,7 @@ package org.apache.druid.msq.util;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.msq.indexing.MSQSelectDestination;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.query.BadQueryContextException;
import org.apache.druid.query.QueryContext;
@ -263,6 +264,12 @@ public class MultiStageQueryContextTest
Assert.assertEquals("nonStrict", MultiStageQueryContext.getMSQMode(QueryContext.of(propertyMap)));
}
/**
 * With no {@code selectDestination} entry in the query context, the resolved destination is
 * {@link MSQSelectDestination#TASK_REPORT}.
 */
@Test
public void limitSelectResultReturnsDefaultValue()
{
  final QueryContext emptyContext = QueryContext.empty();
  final MSQSelectDestination resolved = MultiStageQueryContext.getSelectDestination(emptyContext);
  Assert.assertEquals(MSQSelectDestination.TASK_REPORT, resolved);
}
@Test
public void testUseAutoSchemas()
{

View File

@ -31,6 +31,7 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.RetryUtils;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.http.client.response.StatusResponseHolder;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.indexing.report.MSQTaskReport;
@ -200,15 +201,17 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper<MsqQueryWithResu
List<Map<String, Object>> actualResults = new ArrayList<>();
List<Object[]> results = resultsReport.getResults();
Yielder<Object[]> yielder = resultsReport.getResultYielder();
List<MSQResultsReport.ColumnAndType> rowSignature = resultsReport.getSignature();
for (Object[] row : results) {
while (!yielder.isDone()) {
Object[] row = yielder.get();
Map<String, Object> rowWithFieldNames = new LinkedHashMap<>();
for (int i = 0; i < row.length; ++i) {
rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]);
}
actualResults.add(rowWithFieldNames);
yielder = yielder.next(null);
}
QueryResultVerifier.ResultVerificationObject resultsComparison = QueryResultVerifier.compareResults(