SQL: introduce a query builder for the Rest tests (#55094) (#56221)

* Introduce a query builder for the rest tests

The new BaseRestSqlTestCase.RequestObjectBuilder class is a helper class
to build REST request objects for the tests. Consequently, "manual" string
concatenation to form JSON is done away with.

The class mimics SqlQueryRequestBuilder API.

(cherry picked from commit c8363f04c029542c233a758e9286d33c51d9c0c4)
This commit is contained in:
Bogdan Pintea 2020-05-05 18:55:41 +02:00 committed by GitHub
parent e4f2c3105d
commit 23c35e32f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 357 additions and 196 deletions

View File

@ -23,10 +23,9 @@ import java.util.List;
import java.util.Map;
import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.mode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.query;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.randomMode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.toMap;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.version;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.SQL_QUERY_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.columnInfo;
@ -111,7 +110,7 @@ public class RestSqlMultinodeIT extends ESRestTestCase {
expected.put("rows", singletonList(singletonList(count)));
Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT);
request.setJsonEntity("{\"query\": \"SELECT COUNT(*) FROM test\"" + mode(mode) + version(mode) + "}");
request.setJsonEntity(query("SELECT COUNT(*) FROM test").mode(mode).toString());
Map<String, Object> actual = toMap(client.performRequest(request), mode);
if (false == expected.equals(actual)) {

View File

@ -30,9 +30,9 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.mode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.cursor;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.query;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.randomMode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.version;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.SQL_QUERY_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.columnInfo;
import static org.hamcrest.Matchers.containsString;
@ -71,11 +71,9 @@ public class RestSqlSecurityIT extends SqlSecurityTestCase {
public void expectScrollMatchesAdmin(String adminSql, String user, String userSql) throws Exception {
String mode = randomMode();
Map<String, Object> adminResponse = runSql(null,
new StringEntity("{\"query\": \"" + adminSql + "\", \"fetch_size\": 1" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
new StringEntity(query(adminSql).mode(mode).fetchSize(1).toString(), ContentType.APPLICATION_JSON), mode);
Map<String, Object> otherResponse = runSql(user,
new StringEntity("{\"query\": \"" + adminSql + "\", \"fetch_size\": 1" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
new StringEntity(query(adminSql).mode(mode).fetchSize(1).toString(), ContentType.APPLICATION_JSON), mode);
String adminCursor = (String) adminResponse.remove("cursor");
String otherCursor = (String) otherResponse.remove("cursor");
@ -83,12 +81,10 @@ public class RestSqlSecurityIT extends SqlSecurityTestCase {
assertNotNull(otherCursor);
assertResponse(adminResponse, otherResponse);
while (true) {
adminResponse = runSql(null,
new StringEntity("{\"cursor\": \"" + adminCursor + "\"" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
otherResponse = runSql(user,
new StringEntity("{\"cursor\": \"" + otherCursor + "\"" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
adminResponse = runSql(null, new StringEntity(cursor(adminCursor).mode(mode).toString(),
ContentType.APPLICATION_JSON), mode);
otherResponse = runSql(user, new StringEntity(cursor(otherCursor).mode(mode).toString(),
ContentType.APPLICATION_JSON), mode);
adminCursor = (String) adminResponse.remove("cursor");
otherCursor = (String) otherResponse.remove("cursor");
assertResponse(adminResponse, otherResponse);
@ -183,8 +179,7 @@ public class RestSqlSecurityIT extends SqlSecurityTestCase {
}
private static Map<String, Object> runSql(@Nullable String asUser, String mode, String sql) throws IOException {
return runSql(asUser, new StringEntity("{\"query\": \"" + sql + "\"" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
return runSql(asUser, new StringEntity(query(sql).mode(mode).toString(), ContentType.APPLICATION_JSON), mode);
}
private static Map<String, Object> runSql(@Nullable String asUser, HttpEntity entity, String mode) throws IOException {
@ -237,15 +232,13 @@ public class RestSqlSecurityIT extends SqlSecurityTestCase {
final String mode = randomMode();
Map<String, Object> adminResponse = RestActions.runSql(null,
new StringEntity("{\"query\": \"SELECT * FROM test\", \"fetch_size\": 1" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), mode);
new StringEntity(query("SELECT * FROM test").mode(mode).fetchSize(1).toString(), ContentType.APPLICATION_JSON), mode);
String cursor = (String) adminResponse.remove("cursor");
assertNotNull(cursor);
ResponseException e = expectThrows(ResponseException.class, () -> RestActions.runSql("full_access",
new StringEntity("{\"cursor\":\"" + cursor + "\"" + mode(mode) + version(mode) + "}", ContentType.APPLICATION_JSON),
mode));
new StringEntity(cursor(cursor).mode(mode).toString(), ContentType.APPLICATION_JSON), mode));
// TODO return a better error message for bad scrolls
assertThat(e.getMessage(), containsString("No search context found for id"));
assertEquals(404, e.getResponse().getStatusLine().getStatusCode());

View File

@ -30,10 +30,9 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.mode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.query;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.randomMode;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.toMap;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.version;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.SQL_QUERY_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.columnInfo;
@ -178,8 +177,7 @@ public class UserFunctionIT extends ESRestTestCase {
options.addHeader("es-security-runas-user", asUser);
request.setOptions(options);
}
request.setEntity(new StringEntity("{\"query\": \"" + sql + "\"" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON));
request.setEntity(new StringEntity(query(sql).mode(mode).toString(), ContentType.APPLICATION_JSON));
return toMap(client().performRequest(request), mode);
}

View File

@ -20,7 +20,7 @@ public class RestSqlIT extends RestSqlTestCase {
public void testErrorMessageForTranslatingQueryWithWhereEvaluatingToFalse() throws IOException {
index("{\"foo\":1}");
expectBadRequest(() -> runTranslateSql("{\"query\":\"SELECT * FROM test WHERE foo = 1 AND foo = 2\"}"),
expectBadRequest(() -> runTranslateSql(query("SELECT * FROM test WHERE foo = 1 AND foo = 2").toString()),
containsString("Cannot generate a query DSL for an SQL query that either its WHERE clause evaluates " +
"to FALSE or doesn't operate on a table (missing a FROM clause), sql statement: " +
"[SELECT * FROM test WHERE foo = 1 AND foo = 2]"));
@ -28,31 +28,29 @@ public class RestSqlIT extends RestSqlTestCase {
public void testErrorMessageForTranslatingQueryWithLocalExecution() throws IOException {
index("{\"foo\":1}");
expectBadRequest(() -> runTranslateSql("{\"query\":\"SELECT SIN(PI())\"}"),
expectBadRequest(() -> runTranslateSql(query("SELECT SIN(PI())").toString()),
containsString("Cannot generate a query DSL for an SQL query that either its WHERE clause evaluates " +
"to FALSE or doesn't operate on a table (missing a FROM clause), sql statement: [SELECT SIN(PI())]"));
}
public void testErrorMessageForTranslatingSQLCommandStatement() throws IOException {
index("{\"foo\":1}");
expectBadRequest(() -> runTranslateSql("{\"query\":\"SHOW FUNCTIONS\"}"),
expectBadRequest(() -> runTranslateSql(query("SHOW FUNCTIONS").toString()),
containsString("Cannot generate a query DSL for a special SQL command " +
"(e.g.: DESCRIBE, SHOW), sql statement: [SHOW FUNCTIONS]"));
}
public void testErrorMessageForInvalidParamDataType() throws IOException {
// proto.Mode not available
expectBadRequest(() -> runTranslateSql(
"{\"query\":\"SELECT null WHERE 0 = ? \"" + mode("odbc") + version("odbc") +
", \"params\":[{\"type\":\"invalid\", \"value\":\"irrelevant\"}]}"),
containsString("Invalid parameter data type [invalid]")
);
query("SELECT null WHERE 0 = ?").mode("odbc").params("[{\"type\":\"invalid\", \"value\":\"irrelevant\"}]").toString()),
containsString("Invalid parameter data type [invalid]"));
}
public void testErrorMessageForInvalidParamSpec() throws IOException {
expectBadRequest(() -> runTranslateSql(
"{\"query\":\"SELECT null WHERE 0 = ? \"" + mode("odbc") + version("odbc") +
", \"params\":[{\"type\":\"SHAPE\", \"value\":false}]}"),
containsString("Cannot cast value [false] of type [BOOLEAN] to parameter type [SHAPE]")
);
query("SELECT null WHERE 0 = ?").mode("odbc").params("[{\"type\":\"SHAPE\", \"value\":false}]").toString()),
containsString("Cannot cast value [false] of type [BOOLEAN] to parameter type [SHAPE]"));
}
}

View File

@ -14,6 +14,7 @@ import org.elasticsearch.common.Strings;
import org.elasticsearch.common.time.DateUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.sql.proto.Mode;
import org.elasticsearch.xpack.sql.qa.jdbc.JdbcIntegrationTestCase;
import org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase;
import org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase;
@ -51,11 +52,10 @@ public abstract class CustomDateFormatTestCase extends BaseRestSqlTestCase {
index(docs);
Request request = new Request("POST", RestSqlTestCase.SQL_QUERY_REST_ENDPOINT);
request.setEntity(new StringEntity("{\"query\":\"SELECT COUNT(*) AS c FROM test WHERE "
+ datesConditions.toString() + "\""
+ mode("plain")
+ ",\"time_zone\":\"" + zID + "\"" + "}", ContentType.APPLICATION_JSON));
final String query = "SELECT COUNT(*) AS c FROM test WHERE " + datesConditions.toString();
request.setEntity(new StringEntity(query(query).mode(Mode.PLAIN).timeZone(zID).toString(),
ContentType.APPLICATION_JSON));
Response response = client().performRequest(request);
String expectedJsonSnippet = "{\"columns\":[{\"name\":\"c\",\"type\":\"long\"}],\"rows\":[[";
try (InputStream content = response.getEntity().getContent()) {

View File

@ -838,8 +838,8 @@ public abstract class FieldExtractorTestCase extends BaseRestSqlTestCase {
Request request = new Request("POST", RestSqlTestCase.SQL_QUERY_REST_ENDPOINT);
request.addParameter("error_trace", "true");
request.addParameter("pretty", "true");
request.setEntity(new StringEntity("{\"query\":\"" + query + "\",\"mode\":\"plain\"}", ContentType.APPLICATION_JSON));
request.setEntity(new StringEntity(query(query).mode("plain").toString(), ContentType.APPLICATION_JSON));
return request;
}

View File

@ -31,8 +31,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.sql.proto.Mode.CLI;
import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_QUERY_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.version;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.mode;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.query;
public abstract class SqlProtocolTestCase extends ESRestTestCase {
@ -131,12 +130,11 @@ public abstract class SqlProtocolTestCase extends ESRestTestCase {
@SuppressWarnings({ "unchecked" })
private void assertFloatingPointNumbersReturnTypes(Request request, Mode mode) throws IOException {
String requestContent = "{\"query\":\"SELECT "
String requestContent = query("SELECT "
+ "CAST(1234.34 AS REAL) AS float_positive,"
+ "CAST(-1234.34 AS REAL) AS float_negative,"
+ "1234567890123.34 AS double_positive,"
+ "-1234567890123.34 AS double_negative\""
+ mode(mode.toString()) + version(mode.toString()) + "}";
+ "-1234567890123.34 AS double_negative").mode(mode).toString();
request.setEntity(new StringEntity(requestContent, ContentType.APPLICATION_JSON));
Map<String, Object> map;
@ -219,7 +217,7 @@ public abstract class SqlProtocolTestCase extends ESRestTestCase {
private Map<String, Object> runSql(Mode mode, String sql, boolean columnar) throws IOException {
Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT);
String requestContent = "{\"query\":\"" + sql + "\"" + mode(mode.toString()) + version(mode.toString()) + "}";
String requestContent = query(sql).mode(mode).toString();
String format = randomFrom(XContentType.values()).name().toLowerCase(Locale.ROOT);
// add a client_id to the request

View File

@ -9,7 +9,6 @@ package org.elasticsearch.xpack.sql.qa.rest;
import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.cbor.CborXContent;
import org.elasticsearch.common.xcontent.json.JsonXContent;
@ -19,10 +18,120 @@ import org.elasticsearch.xpack.sql.proto.StringUtils;
import java.io.IOException;
import java.io.InputStream;
import java.util.Locale;
import java.util.Map;
import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CLIENT_ID_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
public abstract class BaseRestSqlTestCase extends ESRestTestCase {
public static class RequestObjectBuilder {
private StringBuilder request;
private final boolean isQuery;
private RequestObjectBuilder(String init, boolean isQuery) {
request = new StringBuilder(init);
this.isQuery = isQuery;
}
public static RequestObjectBuilder query(String query) {
return new RequestObjectBuilder(field(QUERY_NAME, query).substring(1), true);
}
public static RequestObjectBuilder cursor(String cursor) {
return new RequestObjectBuilder(field(CURSOR_NAME, cursor).substring(1), false);
}
public RequestObjectBuilder version(String version) {
request.append(field(VERSION_NAME, version));
return this;
}
public RequestObjectBuilder mode(Object m) {
String modeString = m.toString();
request.append(field(MODE_NAME, modeString));
if (isQuery) {
Mode mode = (m instanceof Mode) ? (Mode)m : Mode.fromString(modeString);
if (Mode.isDedicatedClient(mode)) {
version(Version.CURRENT.toString());
}
}
return this;
}
public RequestObjectBuilder fetchSize(Integer fetchSize) {
request.append(field(FETCH_SIZE_NAME, fetchSize));
return this;
}
public RequestObjectBuilder timeZone(String timeZone) {
request.append(field(TIME_ZONE_NAME, timeZone));
return this;
}
public RequestObjectBuilder clientId(String clientId) {
request.append(field(CLIENT_ID_NAME, clientId));
return this;
}
public RequestObjectBuilder filter(String filter) {
request.append(field(FILTER_NAME, filter));
return this;
}
public RequestObjectBuilder params(String params) {
request.append(field(PARAMS_NAME, params));
return this;
}
public RequestObjectBuilder columnar(Boolean columnar) {
request.append(field(COLUMNAR_NAME, columnar));
return this;
}
public RequestObjectBuilder binaryFormat(Boolean binaryFormat) {
request.append(field(BINARY_FORMAT_NAME, binaryFormat));
return this;
}
private static String field(String name, Object value) {
if (value == null) {
return StringUtils.EMPTY;
}
String field = "\"" + name + "\":";
if (value instanceof String) {
if (((String) value).isEmpty()) {
return StringUtils.EMPTY;
}
String lowerName = name.toLowerCase(Locale.ROOT);
if (lowerName.equals(PARAMS_NAME) || lowerName.equals(FILTER_NAME)) {
field += value;
} else {
field += "\"" + value + "\"";
}
} else {
field += value;
}
return "," + field;
}
@Override
public String toString() {
return "{" + request.toString() + "}";
}
}
protected void index(String... docs) throws IOException {
Request request = new Request("POST", "/test/_bulk");
request.addParameter("refresh", "true");
@ -35,20 +144,16 @@ public abstract class BaseRestSqlTestCase extends ESRestTestCase {
client().performRequest(request);
}
public static String mode(String mode) {
return Strings.isEmpty(mode) ? StringUtils.EMPTY : ",\"mode\":\"" + mode + "\"";
public static RequestObjectBuilder query(String query) {
return RequestObjectBuilder.query(query);
}
public static String version(String mode) {
Mode m = Mode.fromString(mode);
if (Mode.isDedicatedClient(m)) {
return ",\"version\":" + "\"" + Version.CURRENT.toString() + "\"";
}
return StringUtils.EMPTY;
public static RequestObjectBuilder cursor(String query) {
return RequestObjectBuilder.cursor(query);
}
public static String randomMode() {
return randomFrom(StringUtils.EMPTY, "jdbc", "plain");
return randomFrom(StringUtils.EMPTY, Mode.JDBC.toString(), Mode.PLAIN.toString());
}
/**

View File

@ -101,14 +101,10 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
client().performRequest(request);
boolean columnar = randomBoolean();
String sqlRequest =
"{\"query\":\""
+ " SELECT text, number, SQRT(number) AS s, SCORE()"
String sqlRequest = query(
"SELECT text, number, SQRT(number) AS s, SCORE()"
+ " FROM test"
+ " ORDER BY number, SCORE()\", "
+ "\"mode\":\"" + mode + "\""
+ version(mode)
+ ", \"fetch_size\":2" + columnarParameter(columnar) + "}";
+ " ORDER BY number, SCORE()").mode(mode).fetchSize(2).columnar(columnarValue(columnar)).toString();
Number value = xContentDependentFloatingNumberValue(mode, 1f);
String cursor = null;
@ -118,8 +114,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
response = runSql(new StringEntity(sqlRequest, ContentType.APPLICATION_JSON), "", mode);
} else {
columnar = randomBoolean();
response = runSql(new StringEntity("{\"cursor\":\"" + cursor + "\"" + mode(mode) + version(mode) +
columnarParameter(columnar) + "}",
response = runSql(new StringEntity(cursor(cursor).mode(mode).columnar(columnarValue(columnar)).toString(),
ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode);
}
@ -154,8 +149,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
} else {
expected.put("rows", emptyList());
}
assertResponse(expected, runSql(new StringEntity("{ \"cursor\":\"" + cursor + "\"" + mode(mode) + version(mode) +
columnarParameter(columnar) + "}",
assertResponse(expected, runSql(new StringEntity(cursor(cursor).mode(mode).columnar(columnarValue(columnar)).toString(),
ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}
@ -188,12 +182,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
ZoneId zoneId = randomZone();
String mode = randomMode();
String sqlRequest =
"{\"query\":\"SELECT DATE_PART('TZOFFSET', date) AS tz FROM test_date_timezone ORDER BY date\","
+ "\"time_zone\":\"" + zoneId.getId() + "\", "
+ "\"mode\":\"" + mode + "\""
+ version(mode)
+ ",\"fetch_size\":2}";
String sqlRequest = query("SELECT DATE_PART('TZOFFSET', date) AS tz FROM test_date_timezone ORDER BY date")
.timeZone(zoneId.getId()).mode(mode).fetchSize(2).toString();
String cursor = null;
for (int i = 0; i <= datetimes.length; i += 2) {
@ -204,8 +194,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
expected.put("columns", singletonList(columnInfo(mode, "tz", "integer", JDBCType.INTEGER, 11)));
response = runSql(new StringEntity(sqlRequest, ContentType.APPLICATION_JSON), "", mode);
} else {
response = runSql(new StringEntity("{\"cursor\":\"" + cursor + "\"" + mode(mode) + version(mode) + "}",
ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode);
response = runSql(new StringEntity(cursor(cursor).mode(mode).toString(), ContentType.APPLICATION_JSON), StringUtils.EMPTY,
mode);
}
List<Object> values = new ArrayList<>(2);
@ -220,7 +210,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
}
Map<String, Object> expected = new HashMap<>();
expected.put("rows", emptyList());
assertResponse(expected, runSql(new StringEntity("{ \"cursor\":\"" + cursor + "\"" + mode(mode) + version(mode) + "}",
assertResponse(expected, runSql(new StringEntity(cursor(cursor).mode(mode).toString(),
ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}
@ -417,9 +407,9 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
request.addParameter("error_trace", "true");
request.addParameter("pretty", "true");
request.addParameter("format", format);
request.setEntity(new StringEntity("{\"columnar\":true,\"query\":\"SELECT * FROM test\""
+ mode(randomValueOtherThan("jdbc", () -> randomMode())) + "}",
ContentType.APPLICATION_JSON));
request.setEntity(new StringEntity(query("SELECT * FROM test")
.mode(randomValueOtherThan(Mode.JDBC.toString(), BaseRestSqlTestCase::randomMode)).columnar(true).toString(),
ContentType.APPLICATION_JSON));
expectBadRequest(() -> {
client().performRequest(request);
return Collections.emptyMap();
@ -431,7 +421,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
String mode = randomMode();
Request request = new Request("POST", SQL_TRANSLATE_REST_ENDPOINT);
request.setEntity(new StringEntity("{\"columnar\":true,\"query\":\"SELECT * FROM test\"" + mode(mode) + version(mode) + "}",
request.setEntity(new StringEntity(query("SELECT * FROM test").mode(mode).columnar(true).toString(),
ContentType.APPLICATION_JSON));
expectBadRequest(() -> {
client().performRequest(request);
@ -470,20 +460,16 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
private Map<String, Object> runSql(String mode, String sql, String suffix, boolean columnar) throws IOException {
// put an explicit "columnar": false parameter or omit it altogether, it should make no difference
return runSql(new StringEntity("{\"query\":\"" + sql + "\"" + mode(mode) + version(mode) + columnarParameter(columnar) + "}",
return runSql(new StringEntity(query(sql).mode(mode).columnar(columnarValue(columnar)).toString(),
ContentType.APPLICATION_JSON), suffix, mode);
}
protected Map<String, Object> runTranslateSql(String sql) throws IOException {
return runSql(new StringEntity(sql, ContentType.APPLICATION_JSON), "/translate/", Mode.PLAIN.toString());
}
private String columnarParameter(boolean columnar) {
if (columnar == false && randomBoolean()) {
return "";
} else {
return ",\"columnar\":" + columnar;
}
private static Boolean columnarValue(boolean columnar) {
return columnar ? Boolean.TRUE : (randomBoolean() ? null : Boolean.FALSE);
}
protected Map<String, Object> runSql(HttpEntity sql, String suffix, String mode) throws IOException {
@ -573,8 +559,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
options.addHeader("Accept", randomFrom("*/*", "application/json"));
request.setOptions(options);
}
request.setEntity(new StringEntity("{\"query\":\"SELECT * FROM test\"" + mode("plain") + version("plain") +
columnarParameter(columnar) + "}",
request.setEntity(new StringEntity(
query("SELECT * FROM test").mode(Mode.PLAIN).columnar(columnarValue(columnar)).toString(),
ContentType.APPLICATION_JSON));
Response response = client().performRequest(request);
@ -592,8 +578,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
public void testBasicTranslateQuery() throws IOException {
index("{\"test\":\"test\"}", "{\"test\":\"test\"}");
Map<String, Object> response = runTranslateSql("{\"query\":\"SELECT * FROM test\"}");
Map<String, Object> response = runTranslateSql(query("SELECT * FROM test").toString());
assertEquals(1000, response.get("size"));
@SuppressWarnings("unchecked")
Map<String, Object> source = (Map<String, Object>) response.get("_source");
@ -610,8 +596,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
Map<String, Object> expected = new HashMap<>();
expected.put("columns", singletonList(columnInfo(mode, "test", "text", JDBCType.VARCHAR, Integer.MAX_VALUE)));
expected.put("rows", singletonList(singletonList("foo")));
assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT * FROM test\", " +
"\"filter\":{\"match\": {\"test\": \"foo\"}}" + mode(mode) + version(mode) + "}",
assertResponse(expected, runSql(
new StringEntity(query("SELECT * FROM test").mode(mode).filter("{\"match\": {\"test\": \"foo\"}}").toString(),
ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}
@ -634,16 +620,16 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
String params = mode.equals("jdbc") ? "{\"type\": \"integer\", \"value\": 10}, {\"type\": \"keyword\", \"value\": \"foo\"}" :
"10, \"foo\"";
assertResponse(expected, runSql(new StringEntity("{\"query\":\"SELECT test, ? param FROM test WHERE test = ?\", " +
"\"params\":[" + params + "]"
+ mode(mode) + version(mode) + columnarParameter(columnar) + "}", ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
assertResponse(expected, runSql(
new StringEntity(query("SELECT test, ? param FROM test WHERE test = ?").mode(mode).columnar(columnarValue(columnar))
.params("[" + params + "]").toString(), ContentType.APPLICATION_JSON), StringUtils.EMPTY, mode));
}
public void testBasicTranslateQueryWithFilter() throws IOException {
index("{\"test\":\"foo\"}",
"{\"test\":\"bar\"}");
Map<String, Object> response = runTranslateSql("{\"query\":\"SELECT * FROM test\", \"filter\":{\"match\": {\"test\": \"foo\"}}}");
Map<String, Object> response = runTranslateSql(query("SELECT * FROM test").filter("{\"match\": {\"test\": \"foo\"}}").toString());
assertEquals(response.get("size"), 1000);
@SuppressWarnings("unchecked")
@ -682,8 +668,8 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
index("{\"salary\":100}",
"{\"age\":20}");
Map<String, Object> response = runTranslateSql("{\"query\":\"SELECT avg(salary) FROM test GROUP BY abs(age) "
+ "HAVING avg(salary) > 50 LIMIT 10\"}");
Map<String, Object> response = runTranslateSql(
query("SELECT avg(salary) FROM test GROUP BY abs(age) HAVING avg(salary) > 50 LIMIT 10").toString());
assertEquals(response.get("size"), 0);
assertEquals(false, response.get("_source"));
@ -871,7 +857,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
}
index(docs);
String request = "{\"query\":\"SELECT text, number, number + 5 AS sum FROM test ORDER BY number\", \"fetch_size\":2}";
String request = query("SELECT text, number, number + 5 AS sum FROM test ORDER BY number").fetchSize(2).toString();
String cursor = null;
for (int i = 0; i < 20; i += 2) {
@ -879,7 +865,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
if (i == 0) {
response = runSqlAsText(StringUtils.EMPTY, new StringEntity(request, ContentType.APPLICATION_JSON), format);
} else {
response = runSqlAsText(StringUtils.EMPTY, new StringEntity("{\"cursor\":\"" + cursor + "\"}",
response = runSqlAsText(StringUtils.EMPTY, new StringEntity(cursor(cursor).toString(),
ContentType.APPLICATION_JSON), format);
}
@ -898,10 +884,10 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
}
Map<String, Object> expected = new HashMap<>();
expected.put("rows", emptyList());
assertResponse(expected, runSql(new StringEntity("{\"cursor\":\"" + cursor + "\"}", ContentType.APPLICATION_JSON),
StringUtils.EMPTY, Mode.PLAIN.toString()));
assertResponse(expected, runSql(new StringEntity(cursor(cursor).toString(), ContentType.APPLICATION_JSON),
StringUtils.EMPTY, Mode.PLAIN.toString()));
Map<String, Object> response = runSql(new StringEntity("{\"cursor\":\"" + cursor + "\"}", ContentType.APPLICATION_JSON),
Map<String, Object> response = runSql(new StringEntity(cursor(cursor).toString(), ContentType.APPLICATION_JSON),
"/close", Mode.PLAIN.toString());
assertEquals(true, response.get("succeeded"));
@ -909,7 +895,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
}
private Tuple<String, String> runSqlAsText(String sql, String accept) throws IOException {
return runSqlAsText(StringUtils.EMPTY, new StringEntity("{\"query\":\"" + sql + "\"}", ContentType.APPLICATION_JSON), accept);
return runSqlAsText(StringUtils.EMPTY, new StringEntity(query(sql).toString(), ContentType.APPLICATION_JSON), accept);
}
/**
@ -938,7 +924,7 @@ public abstract class RestSqlTestCase extends BaseRestSqlTestCase implements Err
Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT);
request.addParameter("error_trace", "true");
request.addParameter("format", format);
request.setJsonEntity("{\"query\":\"" + sql + "\"}");
request.setJsonEntity(query(sql).toString());
Response response = client().performRequest(request);
return new Tuple<>(

View File

@ -29,8 +29,7 @@ import java.util.Map;
import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_QUERY_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_STATS_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_TRANSLATE_REST_ENDPOINT;
import static org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase.version;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.mode;
import static org.elasticsearch.xpack.sql.qa.rest.RestSqlTestCase.query;
public abstract class RestSqlUsageTestCase extends ESRestTestCase {
private List<IndexDocument> testData = Arrays.asList(
@ -251,10 +250,10 @@ public abstract class RestSqlUsageTestCase extends ESRestTestCase {
options.addHeader("Accept", randomFrom("*/*", "application/json"));
request.setOptions(options);
}
request.setEntity(new StringEntity("{\"query\":\"" + sql + "\"}", ContentType.APPLICATION_JSON));
request.setEntity(new StringEntity(query(sql).toString(), ContentType.APPLICATION_JSON));
client().performRequest(request);
}
private void runSql(String sql) throws IOException {
Mode mode = Mode.PLAIN;
if (clientType.equals(ClientType.JDBC.toString())) {
@ -294,8 +293,8 @@ public abstract class RestSqlUsageTestCase extends ESRestTestCase {
options.addHeader("Accept", randomFrom("*/*", "application/json"));
request.setOptions(options);
}
request.setEntity(new StringEntity("{\"query\":\"" + sql + "\"" + mode(mode) + version(mode) +
(ignoreClientType ? StringUtils.EMPTY : ",\"client_id\":\"" + restClient + "\"") + "}", ContentType.APPLICATION_JSON));
request.setEntity(new StringEntity(query(sql).mode(mode).clientId(ignoreClientType ? StringUtils.EMPTY : restClient).toString(),
ContentType.APPLICATION_JSON));
client().performRequest(request);
}

View File

@ -38,6 +38,17 @@ import java.util.Objects;
import java.util.function.Supplier;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.xpack.sql.proto.Protocol.CLIENT_ID_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.REQUEST_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
/**
* Base class for requests that contain sql queries (Query and Translate)
@ -53,18 +64,17 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
private QueryBuilder filter = null;
private List<SqlTypedParamValue> params = Collections.emptyList();
// TODO: define all REST request object field names in a protocol class as unique source
static final ParseField QUERY = new ParseField("query");
static final ParseField CURSOR = new ParseField("cursor");
static final ParseField PARAMS = new ParseField("params");
static final ParseField TIME_ZONE = new ParseField("time_zone");
static final ParseField FETCH_SIZE = new ParseField("fetch_size");
static final ParseField REQUEST_TIMEOUT = new ParseField("request_timeout");
static final ParseField PAGE_TIMEOUT = new ParseField("page_timeout");
static final ParseField FILTER = new ParseField("filter");
static final ParseField MODE = new ParseField("mode");
static final ParseField CLIENT_ID = new ParseField("client_id");
static final ParseField CLIENT_VERSION = new ParseField("version");
static final ParseField QUERY = new ParseField(QUERY_NAME);
static final ParseField CURSOR = new ParseField(CURSOR_NAME);
static final ParseField PARAMS = new ParseField(PARAMS_NAME);
static final ParseField TIME_ZONE = new ParseField(TIME_ZONE_NAME);
static final ParseField FETCH_SIZE = new ParseField(FETCH_SIZE_NAME);
static final ParseField REQUEST_TIMEOUT = new ParseField(REQUEST_TIMEOUT_NAME);
static final ParseField PAGE_TIMEOUT = new ParseField(PAGE_TIMEOUT_NAME);
static final ParseField FILTER = new ParseField(FILTER_NAME);
static final ParseField MODE = new ParseField(MODE_NAME);
static final ParseField CLIENT_ID = new ParseField(CLIENT_ID_NAME);
static final ParseField VERSION = new ParseField(VERSION_NAME);
public AbstractSqlQueryRequest() {
super();
@ -89,14 +99,14 @@ public abstract class AbstractSqlQueryRequest extends AbstractSqlRequest impleme
parser.declareString(AbstractSqlQueryRequest::query, QUERY);
parser.declareString((request, mode) -> request.mode(Mode.fromString(mode)), MODE);
parser.declareString(AbstractSqlRequest::clientId, CLIENT_ID);
parser.declareString(AbstractSqlRequest::version, CLIENT_VERSION);
parser.declareString(AbstractSqlRequest::version, VERSION);
parser.declareField(AbstractSqlQueryRequest::params, AbstractSqlQueryRequest::parseParams, PARAMS, ValueType.VALUE_ARRAY);
parser.declareString((request, zoneId) -> request.zoneId(ZoneId.of(zoneId)), TIME_ZONE);
parser.declareInt(AbstractSqlQueryRequest::fetchSize, FETCH_SIZE);
parser.declareString((request, timeout) -> request.requestTimeout(TimeValue.parseTimeValue(timeout, Protocol.REQUEST_TIMEOUT,
"request_timeout")), REQUEST_TIMEOUT);
REQUEST_TIMEOUT_NAME)), REQUEST_TIMEOUT);
parser.declareString(
(request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, "page_timeout")),
(request, timeout) -> request.pageTimeout(TimeValue.parseTimeValue(timeout, Protocol.PAGE_TIMEOUT, PAGE_TIMEOUT_NAME)),
PAGE_TIMEOUT);
parser.declareObject(AbstractSqlQueryRequest::filter,
(p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p), FILTER);

View File

@ -21,7 +21,7 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.CLIENT_ID;
import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.CLIENT_VERSION;
import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.VERSION;
import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.CURSOR;
import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.MODE;
@ -43,7 +43,7 @@ public class SqlClearCursorRequest extends AbstractSqlRequest {
PARSER.declareString(constructorArg(), CURSOR);
PARSER.declareString(optionalConstructorArg(), MODE);
PARSER.declareString(optionalConstructorArg(), CLIENT_ID);
PARSER.declareString(optionalConstructorArg(), CLIENT_VERSION);
PARSER.declareString(optionalConstructorArg(), VERSION);
}
private String cursor;

View File

@ -25,16 +25,20 @@ import java.util.List;
import java.util.Objects;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME;
/**
* Request to perform an sql query
*/
public class SqlQueryRequest extends AbstractSqlQueryRequest {
private static final ObjectParser<SqlQueryRequest, Void> PARSER = objectParser(SqlQueryRequest::new);
static final ParseField COLUMNAR = new ParseField("columnar");
static final ParseField FIELD_MULTI_VALUE_LENIENCY = new ParseField("field_multi_value_leniency");
static final ParseField INDEX_INCLUDE_FROZEN = new ParseField("index_include_frozen");
static final ParseField BINARY_COMMUNICATION = new ParseField("binary_format");
static final ParseField COLUMNAR = new ParseField(COLUMNAR_NAME);
static final ParseField FIELD_MULTI_VALUE_LENIENCY = new ParseField(FIELD_MULTI_VALUE_LENIENCY_NAME);
static final ParseField INDEX_INCLUDE_FROZEN = new ParseField(INDEX_INCLUDE_FROZEN_NAME);
static final ParseField BINARY_COMMUNICATION = new ParseField(BINARY_FORMAT_NAME);
static {
PARSER.declareString(SqlQueryRequest::cursor, CURSOR);

View File

@ -33,6 +33,23 @@ import java.util.function.Supplier;
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilter;
import static org.elasticsearch.xpack.sql.action.SqlTestUtils.randomFilterOrNull;
import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CLIENT_ID_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_TYPE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_VALUE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.REQUEST_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS;
public class SqlQueryRequestTests extends AbstractWireSerializingTestCase<SqlQueryRequest> {
@ -170,22 +187,22 @@ public class SqlQueryRequestTests extends AbstractWireSerializingTestCase<SqlQue
private static void toXContent(SqlQueryRequest request, XContentBuilder builder) throws IOException {
builder.startObject();
if (request.query() != null) {
builder.field("query", request.query());
builder.field(QUERY_NAME, request.query());
}
builder.field("mode", request.mode().toString());
builder.field(MODE_NAME, request.mode().toString());
if (request.clientId() != null) {
builder.field("client_id", request.clientId());
builder.field(CLIENT_ID_NAME, request.clientId());
}
if (request.version() != null) {
builder.field("version", request.version().toString());
builder.field(VERSION_NAME, request.version().toString());
}
if (request.params() != null && request.params().isEmpty() == false) {
builder.startArray("params");
builder.startArray(PARAMS_NAME);
for (SqlTypedParamValue val : request.params()) {
if (Mode.isDriver(request.mode())) {
builder.startObject();
builder.field("type", val.type);
builder.field("value", val.value);
builder.field(PARAMS_TYPE_NAME, val.type);
builder.field(PARAMS_VALUE_NAME, val.value);
builder.endObject();
} else {
builder.value(val.value);
@ -194,35 +211,35 @@ public class SqlQueryRequestTests extends AbstractWireSerializingTestCase<SqlQue
builder.endArray();
}
if (request.zoneId() != null) {
builder.field("time_zone", request.zoneId().getId());
builder.field(TIME_ZONE_NAME, request.zoneId().getId());
}
if (request.fetchSize() != Protocol.FETCH_SIZE) {
builder.field("fetch_size", request.fetchSize());
builder.field(FETCH_SIZE_NAME, request.fetchSize());
}
if (request.requestTimeout() != Protocol.REQUEST_TIMEOUT) {
builder.field("request_timeout", request.requestTimeout().getStringRep());
builder.field(REQUEST_TIMEOUT_NAME, request.requestTimeout().getStringRep());
}
if (request.pageTimeout() != Protocol.PAGE_TIMEOUT) {
builder.field("page_timeout", request.pageTimeout().getStringRep());
builder.field(PAGE_TIMEOUT_NAME, request.pageTimeout().getStringRep());
}
if (request.filter() != null) {
builder.field("filter");
builder.field(FILTER_NAME);
request.filter().toXContent(builder, ToXContent.EMPTY_PARAMS);
}
if (request.columnar() != null) {
builder.field("columnar", request.columnar());
builder.field(COLUMNAR_NAME, request.columnar());
}
if (request.fieldMultiValueLeniency()) {
builder.field("field_multi_value_leniency", request.fieldMultiValueLeniency());
builder.field(FIELD_MULTI_VALUE_LENIENCY_NAME, request.fieldMultiValueLeniency());
}
if (request.indexIncludeFrozen()) {
builder.field("index_include_frozen", request.indexIncludeFrozen());
builder.field(INDEX_INCLUDE_FROZEN_NAME, request.indexIncludeFrozen());
}
if (request.binaryCommunication() != null) {
builder.field("binary_format", request.binaryCommunication());
builder.field(BINARY_FORMAT_NAME, request.binaryCommunication());
}
if (request.cursor() != null) {
builder.field("cursor", request.cursor());
builder.field(CURSOR_NAME, request.cursor());
}
builder.endObject();
}

View File

@ -47,8 +47,17 @@ import java.util.Properties;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.REQUEST_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
public class HttpClientRequestTests extends ESTestCase {
private static RawRequestMockWebServer webServer = new RawRequestMockWebServer();
private static final Logger logger = LogManager.getLogger(HttpClientRequestTests.class);
@ -106,16 +115,16 @@ public class HttpClientRequestTests extends ESTestCase {
BytesReference bytesRef = recordedRequest.getBodyAsBytes();
Map<String, Object> reqContent = XContentHelper.convertToMap(bytesRef, false, xContentType).v2();
assertTrue(((String) reqContent.get("mode")).equalsIgnoreCase(Mode.CLI.toString()));
assertEquals(isBinary, reqContent.get("binary_format"));
assertEquals(Boolean.FALSE, reqContent.get("columnar"));
assertEquals(fetchSize, reqContent.get("fetch_size"));
assertEquals(query, reqContent.get("query"));
assertEquals("90000ms", reqContent.get("request_timeout"));
assertEquals("45000ms", reqContent.get("page_timeout"));
assertEquals("Z", reqContent.get("time_zone"));
assertTrue(((String) reqContent.get(MODE_NAME)).equalsIgnoreCase(Mode.CLI.toString()));
assertEquals(isBinary, reqContent.get(BINARY_FORMAT_NAME));
assertEquals(Boolean.FALSE, reqContent.get(COLUMNAR_NAME));
assertEquals(fetchSize, reqContent.get(FETCH_SIZE_NAME));
assertEquals(query, reqContent.get(QUERY_NAME));
assertEquals("90000ms", reqContent.get(REQUEST_TIMEOUT_NAME));
assertEquals("45000ms", reqContent.get(PAGE_TIMEOUT_NAME));
assertEquals("Z", reqContent.get(TIME_ZONE_NAME));
prepareMockResponse();
try {
// we don't care what the cursor is, because the ES node that will actually handle the request (as in running an ES search)
@ -131,13 +140,13 @@ public class HttpClientRequestTests extends ESTestCase {
bytesRef = recordedRequest.getBodyAsBytes();
reqContent = XContentHelper.convertToMap(bytesRef, false, xContentType).v2();
assertTrue(((String) reqContent.get("mode")).equalsIgnoreCase(Mode.CLI.toString()));
assertEquals(isBinary, reqContent.get("binary_format"));
assertEquals("90000ms", reqContent.get("request_timeout"));
assertEquals("45000ms", reqContent.get("page_timeout"));
assertTrue(((String) reqContent.get(MODE_NAME)).equalsIgnoreCase(Mode.CLI.toString()));
assertEquals(isBinary, reqContent.get(BINARY_FORMAT_NAME));
assertEquals("90000ms", reqContent.get(REQUEST_TIMEOUT_NAME));
assertEquals("45000ms", reqContent.get(PAGE_TIMEOUT_NAME));
}
private void assertBinaryRequestForDrivers(boolean isBinary, XContentType xContentType) throws URISyntaxException {
String url = "http://" + webServer.getHostName() + ":" + webServer.getPort();
String query = randomAlphaOfLength(256);
@ -176,13 +185,13 @@ public class HttpClientRequestTests extends ESTestCase {
BytesReference bytesRef = recordedRequest.getBodyAsBytes();
Map<String, Object> reqContent = XContentHelper.convertToMap(bytesRef, false, xContentType).v2();
assertTrue(((String) reqContent.get("mode")).equalsIgnoreCase(mode.toString()));
assertEquals(isBinary, reqContent.get("binary_format"));
assertEquals(query, reqContent.get("query"));
assertEquals("Z", reqContent.get("time_zone"));
assertTrue(((String) reqContent.get(MODE_NAME)).equalsIgnoreCase(mode.toString()));
assertEquals(isBinary, reqContent.get(BINARY_FORMAT_NAME));
assertEquals(query, reqContent.get(QUERY_NAME));
assertEquals("Z", reqContent.get(TIME_ZONE_NAME));
}
private void prepareMockResponse() {
webServer.enqueue(new Response().setResponseCode(200).addHeader("Content-Type", "application/json").setBody("{\"rows\":[]}"));
}

View File

@ -14,6 +14,32 @@ import java.time.ZoneId;
* Sql protocol defaults and end-points shared between JDBC and REST protocol implementations
*/
public final class Protocol {
/**
* The attribute names used in the protocol request/response objects.
*/
// requests
public static final String QUERY_NAME = "query";
public static final String CURSOR_NAME = "cursor"; /* request/reply */
public static final String TIME_ZONE_NAME = "time_zone";
public static final String FETCH_SIZE_NAME = "fetch_size";
public static final String REQUEST_TIMEOUT_NAME = "request_timeout";
public static final String PAGE_TIMEOUT_NAME = "page_timeout";
public static final String FILTER_NAME = "filter";
public static final String MODE_NAME = "mode";
public static final String CLIENT_ID_NAME = "client_id";
public static final String VERSION_NAME = "version";
public static final String COLUMNAR_NAME = "columnar";
public static final String BINARY_FORMAT_NAME = "binary_format";
public static final String FIELD_MULTI_VALUE_LENIENCY_NAME = "field_multi_value_leniency";
public static final String INDEX_INCLUDE_FROZEN_NAME = "index_include_frozen";
// params
public static final String PARAMS_NAME = "params";
public static final String PARAMS_TYPE_NAME = "type";
public static final String PARAMS_VALUE_NAME = "value";
// responses
public static final String COLUMNS_NAME = "columns";
public static final String ROWS_NAME = "rows";
public static final ZoneId TIME_ZONE = ZoneId.of("Z");
/**

View File

@ -17,6 +17,22 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CLIENT_ID_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.REQUEST_TIMEOUT_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
/**
* Sql query request for JDBC/CLI client
*/
@ -174,52 +190,52 @@ public class SqlQueryRequest extends AbstractSqlRequest {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
if (query != null) {
builder.field("query", query);
builder.field(QUERY_NAME, query);
}
builder.field("mode", mode().toString());
builder.field(MODE_NAME, mode().toString());
if (clientId() != null) {
builder.field("client_id", clientId());
builder.field(CLIENT_ID_NAME, clientId());
}
if (version() != null) {
builder.field("version", version().toString());
builder.field(VERSION_NAME, version().toString());
}
if (this.params != null && this.params.isEmpty() == false) {
builder.startArray("params");
builder.startArray(PARAMS_NAME);
for (SqlTypedParamValue val : this.params) {
val.toXContent(builder, params);
}
builder.endArray();
}
if (zoneId != null) {
builder.field("time_zone", zoneId.getId());
builder.field(TIME_ZONE_NAME, zoneId.getId());
}
if (fetchSize != Protocol.FETCH_SIZE) {
builder.field("fetch_size", fetchSize);
builder.field(FETCH_SIZE_NAME, fetchSize);
}
if (requestTimeout != Protocol.REQUEST_TIMEOUT) {
builder.field("request_timeout", requestTimeout.getStringRep());
builder.field(REQUEST_TIMEOUT_NAME, requestTimeout.getStringRep());
}
if (pageTimeout != Protocol.PAGE_TIMEOUT) {
builder.field("page_timeout", pageTimeout.getStringRep());
builder.field(PAGE_TIMEOUT_NAME, pageTimeout.getStringRep());
}
if (filter != null) {
builder.field("filter");
builder.field(FILTER_NAME);
filter.toXContent(builder, params);
}
if (columnar != null) {
builder.field("columnar", columnar);
builder.field(COLUMNAR_NAME, columnar);
}
if (fieldMultiValueLeniency) {
builder.field("field_multi_value_leniency", fieldMultiValueLeniency);
builder.field(FIELD_MULTI_VALUE_LENIENCY_NAME, fieldMultiValueLeniency);
}
if (indexIncludeFrozen) {
builder.field("index_include_frozen", indexIncludeFrozen);
builder.field(INDEX_INCLUDE_FROZEN_NAME, indexIncludeFrozen);
}
if (binaryCommunication != null) {
builder.field("binary_format", binaryCommunication);
builder.field(BINARY_FORMAT_NAME, binaryCommunication);
}
if (cursor != null) {
builder.field("cursor", cursor);
builder.field(CURSOR_NAME, cursor);
}
return builder;
}

View File

@ -18,6 +18,9 @@ import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNS_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
import static org.elasticsearch.xpack.sql.proto.Protocol.ROWS_NAME;
/**
* Response to perform an sql query for JDBC/CLI client
@ -31,9 +34,9 @@ public class SqlQueryResponse {
(List<ColumnInfo>) objects[1],
(List<List<Object>>) objects[2]));
public static final ParseField CURSOR = new ParseField("cursor");
public static final ParseField COLUMNS = new ParseField("columns");
public static final ParseField ROWS = new ParseField("rows");
public static final ParseField CURSOR = new ParseField(CURSOR_NAME);
public static final ParseField COLUMNS = new ParseField(COLUMNS_NAME);
public static final ParseField ROWS = new ParseField(ROWS_NAME);
static {
PARSER.declareString(optionalConstructorArg(), CURSOR);