SQL: add support for passing query parameters in REST API calls (#51029) (#51222)

* REST PreparedStatement-like query parameters are now supported in the form of an array of non-object, non-array values where ES SQL parser will try to infer the data type of the value being passed as parameter.

(cherry picked from commit 45b8bf619aecb1c03d7bc0cf06928dcc36005a66)
This commit is contained in:
Andrei Stefan 2020-01-20 16:40:19 +02:00 committed by GitHub
parent 543cc85b78
commit 2908b7e5fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 427 additions and 51 deletions

View File

@ -8,6 +8,7 @@
* <<sql-pagination>> * <<sql-pagination>>
* <<sql-rest-filtering>> * <<sql-rest-filtering>>
* <<sql-rest-columnar>> * <<sql-rest-columnar>>
* <<sql-rest-params>>
* <<sql-rest-fields>> * <<sql-rest-fields>>
[[sql-rest-overview]] [[sql-rest-overview]]
@ -337,7 +338,7 @@ Which will like return the
[[sql-rest-filtering]] [[sql-rest-filtering]]
=== Filtering using {es} query DSL === Filtering using {es} query DSL
You can filter the results that SQL will run on using a standard One can filter the results that SQL will run on using a standard
{es} query DSL by specifying the query in the filter {es} query DSL by specifying the query in the filter
parameter. parameter.
@ -442,6 +443,36 @@ Which looks like:
-------------------------------------------------- --------------------------------------------------
// TESTRESPONSE[s/46ToAwFzQERYRjFaWEo1UVc1a1JtVjBZMmdCQUFBQUFBQUFBQUVXWjBaNlFXbzNOV0pVY21Wa1NUZDJhV2t3V2xwblp3PT3\/\/\/\/\/DwQBZgZhdXRob3IBBHRleHQAAAFmBG5hbWUBBHRleHQAAAFmCnBhZ2VfY291bnQBBGxvbmcBAAFmDHJlbGVhc2VfZGF0ZQEIZGF0ZXRpbWUBAAEP/$body.cursor/] // TESTRESPONSE[s/46ToAwFzQERYRjFaWEo1UVc1a1JtVjBZMmdCQUFBQUFBQUFBQUVXWjBaNlFXbzNOV0pVY21Wa1NUZDJhV2t3V2xwblp3PT3\/\/\/\/\/DwQBZgZhdXRob3IBBHRleHQAAAFmBG5hbWUBBHRleHQAAAFmCnBhZ2VfY291bnQBBGxvbmcBAAFmDHJlbGVhc2VfZGF0ZQEIZGF0ZXRpbWUBAAEP/$body.cursor/]
[[sql-rest-params]]
=== Passing parameters to a query
Using values in a query condition, for example, or in a `HAVING` statement can be done "inline",
by integrating the value in the query string itself:
[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > 300 AND author = 'Frank Herbert' GROUP BY year HAVING COUNT(*) > 0"
}
--------------------------------------------------
// TEST[setup:library]
or it can be done by extracting the values in a separate list of parameters and using question mark placeholders (`?`) in the query string:
[source,console]
--------------------------------------------------
POST /_sql?format=txt
{
"query": "SELECT YEAR(release_date) AS year FROM library WHERE page_count > ? AND author = ? GROUP BY year HAVING COUNT(*) > ?",
"params": [300, "Frank Herbert", 0]
}
--------------------------------------------------
// TEST[setup:library]
[IMPORTANT]
The recommended way of passing values to a query is with question mark placeholders, to avoid any attempts of hacking or SQL injection.
[[sql-rest-fields]] [[sql-rest-fields]]
=== Supported REST parameters === Supported REST parameters
@ -495,6 +526,10 @@ More information available https://docs.oracle.com/javase/8/docs/api/java/time/Z
|false |false
|Whether to include <<frozen-indices, frozen-indices>> in the query execution or not (default). |Whether to include <<frozen-indices, frozen-indices>> in the query execution or not (default).
|params
|none
|Optional list of parameters to replace question mark (`?`) placeholders inside the query.
|=== |===
Do note that most parameters (outside the timeout and `columnar` ones) make sense only during the initial query - any follow-up pagination request only requires the `cursor` parameter as explained in the <<sql-pagination, pagination>> chapter. Do note that most parameters (outside the timeout and `columnar` ones) make sense only during the initial query - any follow-up pagination request only requires the `cursor` parameter as explained in the <<sql-pagination, pagination>> chapter.

View File

@ -544,8 +544,11 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
} else { } else {
expected.put("rows", Arrays.asList(Arrays.asList("foo", 10))); expected.put("rows", Arrays.asList(Arrays.asList("foo", 10)));
} }
String params = mode.equals("jdbc") ? "{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}" :
"10, \"foo\"";
assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT test, ? param FROM test WHERE test = ?\", " + assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT test, ? param FROM test WHERE test = ?\", " +
"\"params\":[{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}]" "\"params\":[" + params + "]"
+ mode(mode) + columnarParameter(columnar) + "}", ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode)); + mode(mode) + columnarParameter(columnar) + "}", ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
} }

View File

@ -12,7 +12,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser.ValueType;
import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.ToXContentFragment;
import org.elasticsearch.common.xcontent.XContentLocation;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentParser.Token;
import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.xpack.sql.proto.Mode; import org.elasticsearch.xpack.sql.proto.Mode;
@ -22,6 +27,7 @@ import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
import java.io.IOException; import java.io.IOException;
import java.time.ZoneId; import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@ -75,11 +81,11 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
parser.declareString(AbstractSqlQueryRequest::query, QUERY); parser.declareString(AbstractSqlQueryRequest::query, QUERY);
parser.declareString((request, mode) -> request.mode(Mode.fromString(mode)), MODE); parser.declareString((request, mode) -> request.mode(Mode.fromString(mode)), MODE);
parser.declareString((request, clientId) -> request.clientId(clientId), CLIENT_ID); parser.declareString((request, clientId) -> request.clientId(clientId), CLIENT_ID);
parser.declareObjectArray(AbstractSqlQueryRequest::params, (p, c) -> SqlTypedParamValue.fromXContent(p), PARAMS); parser.declareField(AbstractSqlQueryRequest::params, AbstractSqlQueryRequest::parseParams, PARAMS, ValueType.VALUE_ARRAY);
parser.declareString((request, zoneId) -> request.zoneId(ZoneId.of(zoneId)), TIME_ZONE); parser.declareString((request, zoneId) -> request.zoneId(ZoneId.of(zoneId)), TIME_ZONE);
parser.declareInt(AbstractSqlQueryRequest::fetchSize, FETCH_SIZE); parser.declareInt(AbstractSqlQueryRequest::fetchSize, FETCH_SIZE);
parser.declareString((request, timeout) -> request.requestTimeout(TimeValue.parseTimeValue(timeout, Protocol.REQUEST_TIMEOUT, parser.declareString((request, timeout) -> request.requestTimeout(TimeValue.parseTimeValue(timeout, Protocol.REQUEST_TIMEOUT,
"request_timeout")), REQUEST_TIMEOUT); "request_timeout")), REQUEST_TIMEOUT);
parser.declareString( parser.declareString(
(request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, "page_timeout")), (request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, "page_timeout")),
PAGE_TIMEOUT); PAGE_TIMEOUT);
@ -118,6 +124,87 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
return this; return this;
} }
private static List<SqlTypedParamValue> parseParams(XContentParser p) throws IOException {
List<SqlTypedParamValue> result = new ArrayList<>();
Token token = p.currentToken();
if (token == Token.START_ARRAY) {
Object value = null;
String type = null;
SqlTypedParamValue previousParam = null;
SqlTypedParamValue currentParam = null;
while ((token = p.nextToken()) != Token.END_ARRAY) {
XContentLocation loc = p.getTokenLocation();
if (token == Token.START_OBJECT) {
// we are at the start of a value/type pair... hopefully
currentParam = SqlTypedParamValue.fromXContent(p);
/*
* Always set the xcontentlocation for the first param just in case the first one happens to not meet the parsing rules
* that are checked later in validateParams method.
* Also, set the xcontentlocation of the param that is different from the previous param in list when it comes to
* its type being explicitly set or inferred.
*/
if ((previousParam != null && previousParam.hasExplicitType() == false) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
} else {
if (token == Token.VALUE_STRING) {
value = p.text();
type = "keyword";
} else if (token == Token.VALUE_NUMBER) {
XContentParser.NumberType numberType = p.numberType();
if (numberType == XContentParser.NumberType.INT) {
value = p.intValue();
type = "integer";
} else if (numberType == XContentParser.NumberType.LONG) {
value = p.longValue();
type = "long";
} else if (numberType == XContentParser.NumberType.FLOAT) {
value = p.floatValue();
type = "float";
} else if (numberType == XContentParser.NumberType.DOUBLE) {
value = p.doubleValue();
type = "double";
}
} else if (token == Token.VALUE_BOOLEAN) {
value = p.booleanValue();
type = "boolean";
} else if (token == Token.VALUE_NULL) {
value = null;
type = "null";
} else {
throw new XContentParseException(loc, "Failed to parse object: unexpected token [" + token + "] found");
}
currentParam = new SqlTypedParamValue(type, value, false);
if ((previousParam != null && previousParam.hasExplicitType() == true) || result.isEmpty()) {
currentParam.tokenLocation(loc);
}
}
result.add(currentParam);
previousParam = currentParam;
}
}
return result;
}
protected static void validateParams(List<SqlTypedParamValue> params, Mode mode) {
for(SqlTypedParamValue param : params) {
if (Mode.isDriver(mode) && param.hasExplicitType() == false) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is an object with a "
+ "value/type pair");
}
if (Mode.isDriver(mode) == false && param.hasExplicitType() == true) {
throw new XContentParseException(param.tokenLocation(), "[params] must be an array where each entry is a single field (no "
+ "objects supported)");
}
}
}
/** /**
* The client's time zone * The client's time zone
*/ */
@ -204,10 +291,11 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
public static void writeSqlTypedParamValue(StreamOutput out, SqlTypedParamValue value) throws IOException { public static void writeSqlTypedParamValue(StreamOutput out, SqlTypedParamValue value) throws IOException {
out.writeString(value.type); out.writeString(value.type);
out.writeGenericValue(value.value); out.writeGenericValue(value.value);
out.writeBoolean(value.hasExplicitType());
} }
public static SqlTypedParamValue readSqlTypedParamValue(StreamInput in) throws IOException { public static SqlTypedParamValue readSqlTypedParamValue(StreamInput in) throws IOException {
return new SqlTypedParamValue(in.readString(), in.readGenericValue()); return new SqlTypedParamValue(in.readString(), in.readGenericValue(), in.readBoolean());
} }
@Override @Override
@ -248,6 +336,6 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(super.hashCode(), query, zoneId, fetchSize, requestTimeout, pageTimeout, filter); return Objects.hash(super.hashCode(), query, params, zoneId, fetchSize, requestTimeout, pageTimeout, filter);
} }
} }

View File

@ -188,6 +188,8 @@ public class SqlQueryRequest extends AbstractSqlQueryRequest {
} }
public static SqlQueryRequest fromXContent(XContentParser parser) { public static SqlQueryRequest fromXContent(XContentParser parser) {
return PARSER.apply(parser, null); SqlQueryRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request;
} }
} }

View File

@ -58,6 +58,7 @@ public class SqlTranslateRequest extends AbstractSqlQueryRequest {
public static SqlTranslateRequest fromXContent(XContentParser parser) { public static SqlTranslateRequest fromXContent(XContentParser parser) {
SqlTranslateRequest request = PARSER.apply(parser, null); SqlTranslateRequest request = PARSER.apply(parser, null);
validateParams(request.params(), request.mode());
return request; return request;
} }

View File

@ -5,31 +5,37 @@
*/ */
package org.elasticsearch.xpack.sql.action; package org.elasticsearch.xpack.sql.action;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.SearchModule;
import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.sql.proto.Mode; import org.elasticsearch.xpack.sql.proto.Mode;
import org.elasticsearch.xpack.sql.proto.Protocol;
import org.elasticsearch.xpack.sql.proto.RequestInfo; import org.elasticsearch.xpack.sql.proto.RequestInfo;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue; import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;
import org.junit.Before; import org.junit.Before;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilter; import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilter;
import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilterOrNull; import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilterOrNull;
import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS; import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS;
public class SqlQueryRequestTests extends AbstractSerializingTestCase<SqlQueryRequest> { public class SqlQueryRequestTests extends AbstractWireSerializingTestCase<SqlQueryRequest> {
public RequestInfo requestInfo; public RequestInfo requestInfo;
@ -60,49 +66,16 @@ public class SqlQueryRequestTests extends AbstractSerializingTestCase<SqlQueryRe
); );
} }
private RequestInfo randomRequestInfo() {
return new RequestInfo(randomFrom(Mode.values()), randomFrom(randomFrom(CLIENT_IDS), requestInfo.clientId()));
}
public List<SqlTypedParamValue> randomParameters() {
if (randomBoolean()) {
return Collections.emptyList();
} else {
int len = randomIntBetween(1, 10);
List<SqlTypedParamValue> arr = new ArrayList<>(len);
for (int i = 0; i < len; i++) {
@SuppressWarnings("unchecked") Supplier<SqlTypedParamValue> supplier = randomFrom(
() -> new SqlTypedParamValue("boolean", randomBoolean()),
() -> new SqlTypedParamValue("long", randomLong()),
() -> new SqlTypedParamValue("double", randomDouble()),
() -> new SqlTypedParamValue("null", null),
() -> new SqlTypedParamValue("keyword", randomAlphaOfLength(10))
);
arr.add(supplier.get());
}
return Collections.unmodifiableList(arr);
}
}
@Override @Override
protected Writeable.Reader<SqlQueryRequest> instanceReader() { protected Writeable.Reader<SqlQueryRequest> instanceReader() {
return SqlQueryRequest::new; return SqlQueryRequest::new;
} }
private TimeValue randomTV() {
return TimeValue.parseTimeValue(randomTimeValue(), null, "test");
}
@Override
protected SqlQueryRequest doParseInstance(XContentParser parser) {
return SqlQueryRequest.fromXContent(parser);
}
@Override @Override
protected SqlQueryRequest mutateInstance(SqlQueryRequest instance) { protected SqlQueryRequest mutateInstance(SqlQueryRequest instance) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
Consumer<SqlQueryRequest> mutator = randomFrom( Consumer<SqlQueryRequest> mutator = randomFrom(
request -> request.requestInfo(randomValueOtherThan(request.requestInfo(), this::randomRequestInfo)), request -> mutateRequestInfo(instance, request),
request -> request.query(randomValueOtherThan(request.query(), () -> randomAlphaOfLength(5))), request -> request.query(randomValueOtherThan(request.query(), () -> randomAlphaOfLength(5))),
request -> request.params(randomValueOtherThan(request.params(), this::randomParameters)), request -> request.params(randomValueOtherThan(request.params(), this::randomParameters)),
request -> request.zoneId(randomValueOtherThan(request.zoneId(), ESTestCase::randomZone)), request -> request.zoneId(randomValueOtherThan(request.zoneId(), ESTestCase::randomZone)),
@ -120,9 +93,134 @@ public class SqlQueryRequestTests extends AbstractSerializingTestCase<SqlQueryRe
return newRequest; return newRequest;
} }
private AbstractSqlQueryRequest mutateRequestInfo(SqlQueryRequest oldRequest, SqlQueryRequest newRequest) {
RequestInfo requestInfo = randomValueOtherThan(newRequest.requestInfo(), this::randomRequestInfo);
newRequest.requestInfo(requestInfo);
if (Mode.isDriver(oldRequest.requestInfo().mode()) && Mode.isDriver(requestInfo.mode()) == false) {
for(SqlTypedParamValue param : oldRequest.params()) {
param.hasExplicitType(false);
}
}
if (Mode.isDriver(oldRequest.requestInfo().mode()) == false && Mode.isDriver(requestInfo.mode())) {
for(SqlTypedParamValue param : oldRequest.params()) {
param.hasExplicitType(true);
}
}
return newRequest;
}
public void testFromXContent() throws IOException {
xContentTester(this::createParser, this::createTestInstance, SqlQueryRequestTests::toXContent, this::doParseInstance)
.numberOfTestRuns(NUMBER_OF_TEST_RUNS)
.supportsUnknownFields(false)
.shuffleFieldsExceptions(Strings.EMPTY_ARRAY)
.randomFieldsExcludeFilter(field -> false)
.assertEqualsConsumer(this::assertEqualInstances)
.assertToXContentEquivalence(true)
.test();
}
public void testTimeZoneNullException() { public void testTimeZoneNullException() {
final SqlQueryRequest sqlQueryRequest = createTestInstance(); final SqlQueryRequest sqlQueryRequest = createTestInstance();
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> sqlQueryRequest.zoneId(null)); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> sqlQueryRequest.zoneId(null));
assertEquals("time zone may not be null.", e.getMessage()); assertEquals("time zone may not be null.", e.getMessage());
} }
private RequestInfo randomRequestInfo() {
return new RequestInfo(randomFrom(Mode.values()), randomFrom(randomFrom(CLIENT_IDS), requestInfo.clientId()));
}
private TimeValue randomTV() {
return TimeValue.parseTimeValue(randomTimeValue(), null, "test");
}
public List<SqlTypedParamValue> randomParameters() {
if (randomBoolean()) {
return Collections.emptyList();
} else {
int len = randomIntBetween(1, 10);
List<SqlTypedParamValue> arr = new ArrayList<>(len);
boolean hasExplicitType = Mode.isDriver(this.requestInfo.mode());
for (int i = 0; i < len; i++) {
@SuppressWarnings("unchecked") Supplier<SqlTypedParamValue> supplier = randomFrom(
() -> new SqlTypedParamValue("boolean", randomBoolean(), hasExplicitType),
() -> new SqlTypedParamValue("long", randomLong(), hasExplicitType),
() -> new SqlTypedParamValue("double", randomDouble(), hasExplicitType),
() -> new SqlTypedParamValue("null", null, hasExplicitType),
() -> new SqlTypedParamValue("keyword", randomAlphaOfLength(10), hasExplicitType)
);
arr.add(supplier.get());
}
return Collections.unmodifiableList(arr);
}
}
private SqlQueryRequest doParseInstance(XContentParser parser) {
return SqlQueryRequest.fromXContent(parser);
}
/**
* This is needed because {@link SqlQueryRequest#toXContent(XContentBuilder, org.elasticsearch.common.xcontent.ToXContent.Params)}
* is not serializing {@link SqlTypedParamValue} according to the request's {@link Mode} and it shouldn't, in fact.
* For testing purposes, different serializing methods for {@link SqlTypedParamValue} are necessary so that
* {@link SqlQueryRequest#fromXContent(XContentParser)} populates {@link SqlTypedParamValue#hasExplicitType()}
* properly.
*/
private static void toXContent(SqlQueryRequest request, XContentBuilder builder) throws IOException {
builder.startObject();
if (request.query() != null) {
builder.field("query", request.query());
}
builder.field("mode", request.mode().toString());
if (request.clientId() != null) {
builder.field("client_id", request.clientId());
}
if (request.params() != null && request.params().isEmpty() == false) {
builder.startArray("params");
for (SqlTypedParamValue val : request.params()) {
if (Mode.isDriver(request.mode())) {
builder.startObject();
builder.field("type", val.type);
builder.field("value", val.value);
builder.endObject();
} else {
builder.value(val.value);
}
}
builder.endArray();
}
if (request.zoneId() != null) {
builder.field("time_zone", request.zoneId().getId());
}
if (request.fetchSize() != Protocol.FETCH_SIZE) {
builder.field("fetch_size", request.fetchSize());
}
if (request.requestTimeout() != Protocol.REQUEST_TIMEOUT) {
builder.field("request_timeout", request.requestTimeout().getStringRep());
}
if (request.pageTimeout() != Protocol.PAGE_TIMEOUT) {
builder.field("page_timeout", request.pageTimeout().getStringRep());
}
if (request.filter() != null) {
builder.field("filter");
request.filter().toXContent(builder, ToXContent.EMPTY_PARAMS);
}
if (request.columnar() != null) {
builder.field("columnar", request.columnar());
}
if (request.fieldMultiValueLeniency()) {
builder.field("field_multi_value_leniency", request.fieldMultiValueLeniency());
}
if (request.indexIncludeFrozen()) {
builder.field("index_include_frozen", request.indexIncludeFrozen());
}
if (request.binaryCommunication() != null) {
builder.field("binary_format", request.binaryCommunication());
}
if (request.cursor() != null) {
builder.field("cursor", request.cursor());
}
builder.endObject();
}
} }

View File

@ -9,6 +9,7 @@ package org.elasticsearch.xpack.sql.action;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
@ -21,7 +22,7 @@ import java.util.List;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import static org.hamcrest.Matchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
public class SqlRequestParsersTests extends ESTestCase { public class SqlRequestParsersTests extends ESTestCase {
@ -105,23 +106,139 @@ public class SqlRequestParsersTests extends ESTestCase {
SqlQueryRequest::fromXContent); SqlQueryRequest::fromXContent);
Mode randomMode = randomFrom(Mode.values()); Mode randomMode = randomFrom(Mode.values());
String params;
List<SqlTypedParamValue> list = new ArrayList<>(1);
if (Mode.isDriver(randomMode)) {
params = "{\"value\":123, \"type\":\"whatever\"}";
list.add(new SqlTypedParamValue("whatever", 123, true));
} else {
params = "123";
list.add(new SqlTypedParamValue("integer", 123, false));
}
SqlQueryRequest request = generateRequest("{\"cursor\" : \"whatever\", \"mode\" : \"" SqlQueryRequest request = generateRequest("{\"cursor\" : \"whatever\", \"mode\" : \""
+ randomMode.toString() + "\", \"client_id\" : \"bla\"," + randomMode.toString() + "\", \"client_id\" : \"bla\","
+ "\"query\":\"select\",\"params\":[{\"value\":123, \"type\":\"whatever\"}], \"time_zone\":\"UTC\"," + "\"query\":\"select\","
+ "\"params\":[" + params + "],"
+ " \"time_zone\":\"UTC\","
+ "\"request_timeout\":\"5s\",\"page_timeout\":\"10s\"}", SqlQueryRequest::fromXContent); + "\"request_timeout\":\"5s\",\"page_timeout\":\"10s\"}", SqlQueryRequest::fromXContent);
assertNull(request.clientId()); assertNull(request.clientId());
assertEquals(randomMode, request.mode()); assertEquals(randomMode, request.mode());
assertEquals("whatever", request.cursor()); assertEquals("whatever", request.cursor());
assertEquals("select", request.query()); assertEquals("select", request.query());
List<SqlTypedParamValue> list = new ArrayList<>(1);
list.add(new SqlTypedParamValue("whatever", 123));
assertEquals(list, request.params()); assertEquals(list, request.params());
assertEquals("UTC", request.zoneId().getId()); assertEquals("UTC", request.zoneId().getId());
assertEquals(TimeValue.parseTimeValue("5s", "request_timeout"), request.requestTimeout()); assertEquals(TimeValue.parseTimeValue("5s", "request_timeout"), request.requestTimeout());
assertEquals(TimeValue.parseTimeValue("10s", "page_timeout"), request.pageTimeout()); assertEquals(TimeValue.parseTimeValue("10s", "page_timeout"), request.pageTimeout());
} }
public void testParamsSuccessfulParsingInDriverMode() throws IOException {
Mode driverMode = randomValueOtherThanMany((m) -> Mode.isDriver(m) == false, () -> randomFrom(Mode.values()));
String json = "{" +
" \"params\":[{\"type\":\"integer\",\"value\":35000},"
+ " {\"type\":\"date\",\"value\":\"1960-01-01\"},"
+ " {\"type\":\"boolean\",\"value\":false},"
+ " {\"type\":\"keyword\",\"value\":\"foo\"}]," +
" \"mode\": \"" + driverMode.toString() + "\"" +
"}";
SqlQueryRequest request = generateRequest(json, SqlQueryRequest::fromXContent);
List<SqlTypedParamValue> params = request.params();
assertEquals(4, params.size());
assertEquals(35000, params.get(0).value);
assertEquals("integer", params.get(0).type);
assertTrue(params.get(0).hasExplicitType());
assertEquals("1960-01-01", params.get(1).value);
assertEquals("date", params.get(1).type);
assertTrue(params.get(1).hasExplicitType());
assertEquals(false, params.get(2).value);
assertEquals("boolean", params.get(2).type);
assertTrue(params.get(2).hasExplicitType());
assertEquals("foo", params.get(3).value);
assertEquals("keyword", params.get(3).type);
assertTrue(params.get(3).hasExplicitType());
}
public void testParamsSuccessfulParsingInNonDriverMode() throws IOException {
Mode nonDriverMode = randomValueOtherThanMany(Mode::isDriver, () -> randomFrom(Mode.values()));
String json = "{" +
" \"params\":[35000,\"1960-01-01\",false,\"foo\"]," +
" \"mode\": \"" + nonDriverMode.toString() + "\"" +
"}";
SqlQueryRequest request = generateRequest(json, SqlQueryRequest::fromXContent);
List<SqlTypedParamValue> params = request.params();
assertEquals(4, params.size());
assertEquals(35000, params.get(0).value);
assertEquals("integer", params.get(0).type);
assertFalse(params.get(0).hasExplicitType());
assertEquals("1960-01-01", params.get(1).value);
assertEquals("keyword", params.get(1).type);
assertFalse(params.get(1).hasExplicitType());
assertEquals(false, params.get(2).value);
assertEquals("boolean", params.get(2).type);
assertFalse(params.get(2).hasExplicitType());
assertEquals("foo", params.get(3).value);
assertEquals("keyword", params.get(3).type);
assertFalse(params.get(3).hasExplicitType());
}
public void testParamsParsingFailure_QueryRequest_NonDriver() throws IOException {
Mode m = randomValueOtherThanMany(Mode::isDriver, () -> randomFrom(Mode.values()));
assertXContentParsingErrorMessage("{\"params\":[{\"whatever\":35000},\"1960-01-01\",false,\"foo\"],\"mode\": \""
+ m.toString() + "\"}",
"[sql/query] failed to parse field [params]",
SqlQueryRequest::fromXContent);
assertXContentParsingErrorMessage("{\"params\":[350.123,\"1960-01-01\",{\"foobar\":false},\"foo\"],\"mode\": \"}"
+ m.toString() + "\"}",
"[sql/query] failed to parse field [params]",
SqlQueryRequest::fromXContent);
assertXContentParsingErrorMessage("{\"mode\": \"" + m.toString() + "\",\"params\":[350.123,\"1960-01-01\",false,"
+ "{\"type\":\"keyword\",\"value\":\"foo\"}]}",
"[params] must be an array where each entry is a single field (no objects supported)",
SqlQueryRequest::fromXContent);
}
public void testParamsParsingFailure_TranslateRequest_NonDriver() throws IOException {
Mode m = randomValueOtherThanMany(Mode::isDriver, () -> randomFrom(Mode.values()));
assertXContentParsingErrorMessage("{\"params\":[{\"whatever\":35000},\"1960-01-01\",false,\"foo\"],\"mode\": \""
+ m.toString() + "\"}",
"[sql/query] failed to parse field [params]",
SqlTranslateRequest::fromXContent);
assertXContentParsingErrorMessage("{\"params\":[350.123,\"1960-01-01\",{\"foobar\":false},\"foo\"],\"mode\": \"}"
+ m.toString() + "\"}",
"[sql/query] failed to parse field [params]",
SqlTranslateRequest::fromXContent);
assertXContentParsingErrorMessage("{\"mode\": \"" + m.toString() + "\",\"params\":[350.123,\"1960-01-01\",false,"
+ "{\"type\":\"keyword\",\"value\":\"foo\"}]}",
"[params] must be an array where each entry is a single field (no objects supported)",
SqlTranslateRequest::fromXContent);
}
public void testParamsParsingFailure_Driver() throws IOException {
Mode m = randomValueOtherThanMany((t) -> Mode.isDriver(t) == false, () -> randomFrom(Mode.values()));
assertXContentParsingErrorMessage("{\"params\":[35000,{\"value\":\"1960-01-01\",\"type\":\"date\"},{\"value\":\"foo\","
+ "\"type\":\"keyword\"}],\"mode\": \"" + m.toString() + "\"}",
"[params] must be an array where each entry is an object with a value/type pair",
SqlQueryRequest::fromXContent);
assertXContentParsingErrorMessage("{\"params\":[{\"value\":10,\"type\":\"integer\"},{\"value\":\"1960-01-01\",\"type\":\"date\"},"
+ "false,\"foo\"],\"mode\": \"" + m.toString() + "\"}",
"[params] must be an array where each entry is an object with a value/type pair",
SqlQueryRequest::fromXContent);
assertXContentParsingErrorMessage("{\"mode\": \"" + m.toString() + "\",\"params\":[{\"value\":10,\"type\":\"integer\"},"
+ "{\"value\":\"1960-01-01\",\"type\":\"date\"},{\"foo\":\"bar\"}]}",
"[sql/query] failed to parse field [params]",
SqlQueryRequest::fromXContent);
}
private <R extends AbstractSqlRequest> R generateRequest(String json, Function<XContentParser, R> fromXContent) private <R extends AbstractSqlRequest> R generateRequest(String json, Function<XContentParser, R> fromXContent)
throws IOException { throws IOException {
XContentParser parser = parser(json); XContentParser parser = parser(json);
@ -140,6 +257,12 @@ public class SqlRequestParsersTests extends ESTestCase {
assertThat(e.getCause().getMessage(), containsString(errorMessage)); assertThat(e.getCause().getMessage(), containsString(errorMessage));
} }
private void assertXContentParsingErrorMessage(String json, String errorMessage, Consumer<XContentParser> consumer) throws IOException {
XContentParser parser = parser(json);
Exception e = expectThrows(XContentParseException.class, () -> consumer.accept(parser));
assertThat(e.getMessage(), containsString(errorMessage));
}
private XContentParser parser(String content) throws IOException { private XContentParser parser(String content) throws IOException {
XContentType xContentType = XContentType.JSON; XContentType xContentType = XContentType.JSON;
return xContentType.xContent().createParser( return xContentType.xContent().createParser(

View File

@ -10,6 +10,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentLocation;
import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException; import java.io.IOException;
@ -37,10 +38,33 @@ public class SqlTypedParamValue implements ToXContentObject {
public final Object value; public final Object value;
public final String type; public final String type;
private boolean hasExplicitType; // the type is explicitly set in the request or inferred by the parser
private XContentLocation tokenLocation; // location of the token failing the parsing rules
public SqlTypedParamValue(String type, Object value) { public SqlTypedParamValue(String type, Object value) {
this(type, value, true);
}
public SqlTypedParamValue(String type, Object value, boolean hasExplicitType) {
this.value = value; this.value = value;
this.type = type; this.type = type;
this.hasExplicitType = hasExplicitType;
}
public boolean hasExplicitType() {
return hasExplicitType;
}
public void hasExplicitType(boolean hasExplicitType) {
this.hasExplicitType = hasExplicitType;
}
public XContentLocation tokenLocation() {
return tokenLocation;
}
public void tokenLocation(XContentLocation tokenLocation) {
this.tokenLocation = tokenLocation;
} }
@Override @Override
@ -65,16 +89,18 @@ public class SqlTypedParamValue implements ToXContentObject {
return false; return false;
} }
SqlTypedParamValue that = (SqlTypedParamValue) o; SqlTypedParamValue that = (SqlTypedParamValue) o;
return Objects.equals(value, that.value) && Objects.equals(type, that.type); return Objects.equals(value, that.value)
&& Objects.equals(type, that.type)
&& Objects.equals(hasExplicitType, that.hasExplicitType);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(value, type); return Objects.hash(value, type, hasExplicitType);
} }
@Override @Override
public String toString() { public String toString() {
return String.valueOf(value) + "[" + type + "]"; return String.valueOf(value) + " [" + type + "][" + hasExplicitType + "][" + tokenLocation + "]";
} }
} }