diff --git a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/AbstractCassandraProcessor.java b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/AbstractCassandraProcessor.java index 0c53a35a80..002ec27baf 100644 --- a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/AbstractCassandraProcessor.java +++ b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/AbstractCassandraProcessor.java @@ -28,14 +28,6 @@ import com.datastax.driver.core.Session; import com.datastax.driver.core.TypeCodec; import com.datastax.driver.core.exceptions.AuthenticationException; import com.datastax.driver.core.exceptions.NoHostAvailableException; -import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import javax.net.ssl.SSLContext; import com.datastax.driver.extras.codecs.arrays.ObjectArrayCodec; import org.apache.avro.Schema; import org.apache.avro.SchemaBuilder; @@ -56,6 +48,15 @@ import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.security.util.ClientAuth; import org.apache.nifi.ssl.SSLContextService; +import javax.net.ssl.SSLContext; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + /** * AbstractCassandraProcessor is a base class for Cassandra processors and contains logic and variables common to most * processors integrating with Apache Cassandra. @@ -181,7 +182,6 @@ public abstract class AbstractCassandraProcessor extends AbstractProcessor { descriptors.add(USERNAME); descriptors.add(PASSWORD); descriptors.add(CONSISTENCY_LEVEL); - descriptors.add(COMPRESSION_TYPE); descriptors.add(CHARSET); } @@ -209,12 +209,12 @@ public abstract class AbstractCassandraProcessor extends AbstractProcessor { if (connectionProviderIsSet && contactPointsIsSet) { results.add(new ValidationResult.Builder().subject("Cassandra configuration").valid(false).explanation("both " + CONNECTION_PROVIDER_SERVICE.getDisplayName() + - " and processor level Cassandra configuration cannot be provided at the same time.").build()); + " and processor level Cassandra configuration cannot be provided at the same time.").build()); } if (!connectionProviderIsSet && !contactPointsIsSet) { results.add(new ValidationResult.Builder().subject("Cassandra configuration").valid(false).explanation("either " + CONNECTION_PROVIDER_SERVICE.getDisplayName() + - " or processor level Cassandra configuration has to be provided.").build()); + " or processor level Cassandra configuration has to be provided.").build()); } return results; @@ -224,7 +224,6 @@ public abstract class AbstractCassandraProcessor extends AbstractProcessor { public void onScheduled(ProcessContext context) { final boolean connectionProviderIsSet = context.getProperty(CONNECTION_PROVIDER_SERVICE).isSet(); - // Register codecs registerAdditionalCodecs(); if (connectionProviderIsSet) { @@ -386,7 +385,6 @@ public abstract class AbstractCassandraProcessor extends AbstractProcessor { } else if (dataType.equals(DataType.timestamp())) { return row.getTimestamp(i); - } else if (dataType.equals(DataType.date())) { return row.getDate(i); diff --git a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/QueryCassandra.java b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/QueryCassandra.java index 6212082861..0dac574fc1 100644 --- a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/QueryCassandra.java +++ b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/main/java/org/apache/nifi/processors/cassandra/QueryCassandra.java @@ -19,7 +19,6 @@ package org.apache.nifi.processors.cassandra; import com.datastax.driver.core.ColumnDefinitions; import com.datastax.driver.core.DataType; import com.datastax.driver.core.ResultSet; -import com.datastax.driver.core.ResultSetFuture; import com.datastax.driver.core.Row; import com.datastax.driver.core.Session; import com.datastax.driver.core.exceptions.NoHostAvailableException; @@ -58,7 +57,6 @@ import org.apache.nifi.processor.io.OutputStreamCallback; import org.apache.nifi.processor.util.StandardValidators; import org.apache.nifi.util.StopWatch; -import java.io.BufferedOutputStream; import java.io.IOException; import java.io.OutputStream; import java.nio.charset.Charset; @@ -70,6 +68,7 @@ import java.util.Date; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.TimeZone; @@ -123,6 +122,30 @@ public class QueryCassandra extends AbstractCassandraProcessor { .addValidator(StandardValidators.INTEGER_VALIDATOR) .build(); + public static final PropertyDescriptor MAX_ROWS_PER_FLOW_FILE = new PropertyDescriptor.Builder() + .name("Max Rows Per Flow File") + .description("The maximum number of result rows that will be included in a single FlowFile. This will allow you to break up very large " + + "result sets into multiple FlowFiles. If the value specified is zero, then all rows are returned in a single FlowFile.") + .defaultValue("0") + .required(true) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .addValidator(StandardValidators.INTEGER_VALIDATOR) + .build(); + + public static final PropertyDescriptor OUTPUT_BATCH_SIZE = new PropertyDescriptor.Builder() + .name("qdbt-output-batch-size") + .displayName("Output Batch Size") + .description("The number of output FlowFiles to queue before committing the process session. When set to zero, the session will be committed when all result set rows " + + "have been processed and the output FlowFiles are ready for transfer to the downstream relationship. For large result sets, this can cause a large burst of FlowFiles " + + "to be transferred at the end of processor execution. If this property is set, then when the specified number of FlowFiles are ready for transfer, then the session will " + + "be committed, thus releasing the FlowFiles to the downstream relationship. NOTE: The maxvalue.* and fragment.count attributes will not be set on FlowFiles when this " + + "property is set.") + .defaultValue("0") + .required(true) + .addValidator(StandardValidators.NON_NEGATIVE_INTEGER_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.VARIABLE_REGISTRY) + .build(); + public static final PropertyDescriptor OUTPUT_FORMAT = new PropertyDescriptor.Builder() .name("Output Format") .description("The format to which the result rows will be converted. If JSON is selected, the output will " @@ -165,6 +188,8 @@ public class QueryCassandra extends AbstractCassandraProcessor { _propertyDescriptors.add(CQL_SELECT_QUERY); _propertyDescriptors.add(QUERY_TIMEOUT); _propertyDescriptors.add(FETCH_SIZE); + _propertyDescriptors.add(MAX_ROWS_PER_FLOW_FILE); + _propertyDescriptors.add(OUTPUT_BATCH_SIZE); _propertyDescriptors.add(OUTPUT_FORMAT); _propertyDescriptors.add(TIMESTAMP_FORMAT_PATTERN); propertyDescriptors = Collections.unmodifiableList(_propertyDescriptors); @@ -202,6 +227,7 @@ public class QueryCassandra extends AbstractCassandraProcessor { @Override public void onTrigger(final ProcessContext context, final ProcessSession session) throws ProcessException { FlowFile fileToProcess = null; + if (context.hasIncomingConnection()) { fileToProcess = session.get(); @@ -217,60 +243,90 @@ public class QueryCassandra extends AbstractCassandraProcessor { final String selectQuery = context.getProperty(CQL_SELECT_QUERY).evaluateAttributeExpressions(fileToProcess).getValue(); final long queryTimeout = context.getProperty(QUERY_TIMEOUT).evaluateAttributeExpressions(fileToProcess).asTimePeriod(TimeUnit.MILLISECONDS); final String outputFormat = context.getProperty(OUTPUT_FORMAT).getValue(); + final long maxRowsPerFlowFile = context.getProperty(MAX_ROWS_PER_FLOW_FILE).evaluateAttributeExpressions().asInteger(); + final long outputBatchSize = context.getProperty(OUTPUT_BATCH_SIZE).evaluateAttributeExpressions().asInteger(); final Charset charset = Charset.forName(context.getProperty(CHARSET).evaluateAttributeExpressions(fileToProcess).getValue()); final StopWatch stopWatch = new StopWatch(true); - if (fileToProcess == null) { - fileToProcess = session.create(); - } - try { // The documentation for the driver recommends the session remain open the entire time the processor is running // and states that it is thread-safe. This is why connectionSession is not in a try-with-resources. final Session connectionSession = cassandraSession.get(); - final ResultSetFuture queryFuture = connectionSession.executeAsync(selectQuery); + final ResultSet resultSet; + + if (queryTimeout > 0) { + resultSet = connectionSession.execute(selectQuery, queryTimeout, TimeUnit.MILLISECONDS); + }else{ + resultSet = connectionSession.execute(selectQuery); + } final AtomicLong nrOfRows = new AtomicLong(0L); - fileToProcess = session.write(fileToProcess, new OutputStreamCallback() { - @Override - public void process(final OutputStream rawOut) throws IOException { - try (final OutputStream out = new BufferedOutputStream(rawOut)) { - logger.debug("Executing CQL query {}", new Object[]{selectQuery}); - final ResultSet resultSet; - if (queryTimeout > 0) { - resultSet = queryFuture.getUninterruptibly(queryTimeout, TimeUnit.MILLISECONDS); - if (AVRO_FORMAT.equals(outputFormat)) { - nrOfRows.set(convertToAvroStream(resultSet, out, queryTimeout, TimeUnit.MILLISECONDS)); - } else if (JSON_FORMAT.equals(outputFormat)) { - nrOfRows.set(convertToJsonStream(Optional.of(context), resultSet, out, charset, queryTimeout, TimeUnit.MILLISECONDS)); - } - } else { - resultSet = queryFuture.getUninterruptibly(); - if (AVRO_FORMAT.equals(outputFormat)) { - nrOfRows.set(convertToAvroStream(resultSet, out, 0, null)); - } else if (JSON_FORMAT.equals(outputFormat)) { - nrOfRows.set(convertToJsonStream(Optional.of(context), resultSet, out, charset, 0, null)); - } - } + long flowFileCount = 0; - } catch (final TimeoutException | InterruptedException | ExecutionException e) { - throw new ProcessException(e); + if(fileToProcess == null) { + fileToProcess = session.create(); + } + + while(true) { + + fileToProcess = session.write(fileToProcess, new OutputStreamCallback() { + @Override + public void process(final OutputStream out) throws IOException { + try { + logger.debug("Executing CQL query {}", new Object[]{selectQuery}); + if (queryTimeout > 0) { + if (AVRO_FORMAT.equals(outputFormat)) { + nrOfRows.set(convertToAvroStream(resultSet, maxRowsPerFlowFile, + out, queryTimeout, TimeUnit.MILLISECONDS)); + } else if (JSON_FORMAT.equals(outputFormat)) { + nrOfRows.set(convertToJsonStream(resultSet, maxRowsPerFlowFile, + out, charset, queryTimeout, TimeUnit.MILLISECONDS)); + } + } else { + if (AVRO_FORMAT.equals(outputFormat)) { + nrOfRows.set(convertToAvroStream(resultSet, maxRowsPerFlowFile, + out, 0, null)); + } else if (JSON_FORMAT.equals(outputFormat)) { + nrOfRows.set(convertToJsonStream(resultSet, maxRowsPerFlowFile, + out, charset, 0, null)); + } + } + } catch (final TimeoutException | InterruptedException | ExecutionException e) { + throw new ProcessException(e); + } + } + }); + + // set attribute how many rows were selected + fileToProcess = session.putAttribute(fileToProcess, RESULT_ROW_COUNT, String.valueOf(nrOfRows.get())); + + // set mime.type based on output format + fileToProcess = session.putAttribute(fileToProcess, CoreAttributes.MIME_TYPE.key(), + JSON_FORMAT.equals(outputFormat) ? "application/json" : "application/avro-binary"); + + if (logger.isDebugEnabled()) { + logger.info("{} contains {} records; transferring to 'success'", + new Object[]{fileToProcess, nrOfRows.get()}); + } + session.getProvenanceReporter().modifyContent(fileToProcess, "Retrieved " + nrOfRows.get() + " rows", + stopWatch.getElapsed(TimeUnit.MILLISECONDS)); + session.transfer(fileToProcess, REL_SUCCESS); + + if (outputBatchSize > 0) { + flowFileCount++; + + if (flowFileCount == outputBatchSize) { + session.commitAsync(); + flowFileCount = 0; +// fileToProcess = session.create(); } } - }); - - // set attribute how many rows were selected - fileToProcess = session.putAttribute(fileToProcess, RESULT_ROW_COUNT, String.valueOf(nrOfRows.get())); - - // set mime.type based on output format - fileToProcess = session.putAttribute(fileToProcess, CoreAttributes.MIME_TYPE.key(), - JSON_FORMAT.equals(outputFormat) ? "application/json" : "application/avro-binary"); - - logger.info("{} contains {} Avro records; transferring to 'success'", - new Object[]{fileToProcess, nrOfRows.get()}); - session.getProvenanceReporter().modifyContent(fileToProcess, "Retrieved " + nrOfRows.get() + " rows", - stopWatch.getElapsed(TimeUnit.MILLISECONDS)); - session.transfer(fileToProcess, REL_SUCCESS); + resultSet.fetchMoreResults().get(); + if (resultSet.isExhausted()) { + break; + } + fileToProcess = session.create(); + } } catch (final NoHostAvailableException nhae) { getLogger().error("No host in the Cassandra cluster can be contacted successfully to execute this query", nhae); @@ -279,11 +335,16 @@ public class QueryCassandra extends AbstractCassandraProcessor { // cap the error limit at 10, format the messages, and don't include the stack trace (it is displayed by the // logger message above). getLogger().error(nhae.getCustomMessage(10, true, false)); + if (fileToProcess == null) { + fileToProcess = session.create(); + } fileToProcess = session.penalize(fileToProcess); session.transfer(fileToProcess, REL_RETRY); - } catch (final QueryExecutionException qee) { logger.error("Cannot execute the query with the requested consistency level successfully", qee); + if (fileToProcess == null) { + fileToProcess = session.create(); + } fileToProcess = session.penalize(fileToProcess); session.transfer(fileToProcess, REL_RETRY); @@ -291,29 +352,64 @@ public class QueryCassandra extends AbstractCassandraProcessor { if (context.hasIncomingConnection()) { logger.error("The CQL query {} is invalid due to syntax error, authorization issue, or another " + "validation problem; routing {} to failure", - new Object[]{selectQuery, fileToProcess}, qve); + selectQuery, fileToProcess, qve); + + if (fileToProcess == null) { + fileToProcess = session.create(); + } fileToProcess = session.penalize(fileToProcess); session.transfer(fileToProcess, REL_FAILURE); } else { // This can happen if any exceptions occur while setting up the connection, statement, etc. logger.error("The CQL query {} is invalid due to syntax error, authorization issue, or another " - + "validation problem", new Object[]{selectQuery}, qve); - session.remove(fileToProcess); + + "validation problem", selectQuery, qve); + if (fileToProcess != null) { + session.remove(fileToProcess); + } + context.yield(); + } + } catch (InterruptedException|ExecutionException ex) { + if (context.hasIncomingConnection()) { + logger.error("The CQL query {} has yielded an unknown error, routing {} to failure", + selectQuery, fileToProcess, ex); + + if (fileToProcess == null) { + fileToProcess = session.create(); + } + fileToProcess = session.penalize(fileToProcess); + session.transfer(fileToProcess, REL_FAILURE); + } else { + // This can happen if any exceptions occur while setting up the connection, statement, etc. + logger.error("The CQL query {} has run into an unknown error.", selectQuery, ex); + if (fileToProcess != null) { + session.remove(fileToProcess); + } context.yield(); } } catch (final ProcessException e) { if (context.hasIncomingConnection()) { logger.error("Unable to execute CQL select query {} for {} due to {}; routing to failure", - new Object[]{selectQuery, fileToProcess, e}); + selectQuery, fileToProcess, e); + if (fileToProcess == null) { + fileToProcess = session.create(); + } fileToProcess = session.penalize(fileToProcess); session.transfer(fileToProcess, REL_FAILURE); + } else { logger.error("Unable to execute CQL select query {} due to {}", - new Object[]{selectQuery, e}); - session.remove(fileToProcess); + selectQuery, e); + if (fileToProcess != null) { + session.remove(fileToProcess); + } context.yield(); } } + session.commitAsync(); + } + + private void handleException() { + } @@ -340,54 +436,90 @@ public class QueryCassandra extends AbstractCassandraProcessor { * @throws TimeoutException If a result set fetch has taken longer than the specified timeout * @throws ExecutionException If any error occurs during the result set fetch */ - public static long convertToAvroStream(final ResultSet rs, final OutputStream outStream, + public static long convertToAvroStream(final ResultSet rs, long maxRowsPerFlowFile, + final OutputStream outStream, long timeout, TimeUnit timeUnit) throws IOException, InterruptedException, TimeoutException, ExecutionException { final Schema schema = createSchema(rs); final GenericRecord rec = new GenericData.Record(schema); - final DatumWriter datumWriter = new GenericDatumWriter<>(schema); + try (final DataFileWriter dataFileWriter = new DataFileWriter<>(datumWriter)) { dataFileWriter.create(schema, outStream); - final ColumnDefinitions columnDefinitions = rs.getColumnDefinitions(); + ColumnDefinitions columnDefinitions = rs.getColumnDefinitions(); long nrOfRows = 0; + long rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); + if (columnDefinitions != null) { - do { - // Grab the ones we have - int rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); - if (rowsAvailableWithoutFetching == 0) { - // Get more - if (timeout <= 0 || timeUnit == null) { - rs.fetchMoreResults().get(); + // Grab the ones we have + if (rowsAvailableWithoutFetching == 0 + || rowsAvailableWithoutFetching < maxRowsPerFlowFile) { + // Get more + if (timeout <= 0 || timeUnit == null) { + rs.fetchMoreResults().get(); + } else { + rs.fetchMoreResults().get(timeout, timeUnit); + } + rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); + } + + if(maxRowsPerFlowFile == 0){ + maxRowsPerFlowFile = rowsAvailableWithoutFetching; + } + + Row row; + //Iterator it = rs.iterator(); + while(nrOfRows < maxRowsPerFlowFile){ + try { + row = rs.iterator().next(); + }catch (NoSuchElementException nsee){ + nrOfRows -= 1; + break; + } + + // iterator().next() is like iterator().one() => return null on end + // https://docs.datastax.com/en/drivers/java/2.0/com/datastax/driver/core/ResultSet.html#one-- + if(row == null){ + break; + } + + for (int i = 0; i < columnDefinitions.size(); i++) { + final DataType dataType = columnDefinitions.getType(i); + + if (row.isNull(i)) { + rec.put(i, null); } else { - rs.fetchMoreResults().get(timeout, timeUnit); + rec.put(i, getCassandraObject(row, i, dataType)); } } - for (Row row : rs) { - - for (int i = 0; i < columnDefinitions.size(); i++) { - final DataType dataType = columnDefinitions.getType(i); - - if (row.isNull(i)) { - rec.put(i, null); - } else { - rec.put(i, getCassandraObject(row, i, dataType)); - } - } - dataFileWriter.append(rec); - nrOfRows += 1; - - } - } while (!rs.isFullyFetched()); + dataFileWriter.append(rec); + nrOfRows += 1; + } } return nrOfRows; } } + private static String getFormattedDate(final Optional context, Date value) { + final String dateFormatPattern = context + .map(_context -> _context.getProperty(TIMESTAMP_FORMAT_PATTERN).getValue()) + .orElse(TIMESTAMP_FORMAT_PATTERN.getDefaultValue()); + SimpleDateFormat dateFormat = new SimpleDateFormat(dateFormatPattern); + dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); + return dateFormat.format(value); + } + + public static long convertToJsonStream(final ResultSet rs, long maxRowsPerFlowFile, + final OutputStream outStream, + Charset charset, long timeout, TimeUnit timeUnit) + throws IOException, InterruptedException, TimeoutException, ExecutionException { + return convertToJsonStream(Optional.empty(), rs, maxRowsPerFlowFile, outStream, charset, timeout, timeUnit); + } + /** * Converts a result set into an Json object and writes it to the given stream using the specified character set. * @@ -401,93 +533,108 @@ public class QueryCassandra extends AbstractCassandraProcessor { * @throws TimeoutException If a result set fetch has taken longer than the specified timeout * @throws ExecutionException If any error occurs during the result set fetch */ - public static long convertToJsonStream(final ResultSet rs, final OutputStream outStream, - Charset charset, long timeout, TimeUnit timeUnit) - throws IOException, InterruptedException, TimeoutException, ExecutionException { - return convertToJsonStream(Optional.empty(), rs, outStream, charset, timeout, timeUnit); - } - @VisibleForTesting - static long convertToJsonStream(final Optional context, final ResultSet rs, final OutputStream outStream, - Charset charset, long timeout, TimeUnit timeUnit) + public static long convertToJsonStream(final Optional context, + final ResultSet rs, long maxRowsPerFlowFile, + final OutputStream outStream, + Charset charset, long timeout, TimeUnit timeUnit) throws IOException, InterruptedException, TimeoutException, ExecutionException { try { // Write the initial object brace outStream.write("{\"results\":[".getBytes(charset)); - final ColumnDefinitions columnDefinitions = rs.getColumnDefinitions(); + ColumnDefinitions columnDefinitions = rs.getColumnDefinitions(); long nrOfRows = 0; - if (columnDefinitions != null) { - do { + long rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); - // Grab the ones we have - int rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); - if (rowsAvailableWithoutFetching == 0) { - // Get more - if (timeout <= 0 || timeUnit == null) { - rs.fetchMoreResults().get(); - } else { - rs.fetchMoreResults().get(timeout, timeUnit); - } + if (columnDefinitions != null) { + + // Grab the ones we have + if (rowsAvailableWithoutFetching == 0) { + // Get more + if (timeout <= 0 || timeUnit == null) { + rs.fetchMoreResults().get(); + } else { + rs.fetchMoreResults().get(timeout, timeUnit); + } + rowsAvailableWithoutFetching = rs.getAvailableWithoutFetching(); + } + + if(maxRowsPerFlowFile == 0){ + maxRowsPerFlowFile = rowsAvailableWithoutFetching; + } + + Row row; + while(nrOfRows < maxRowsPerFlowFile){ + try { + row = rs.iterator().next(); + }catch (NoSuchElementException nsee){ + nrOfRows -= 1; + break; } - for (Row row : rs) { - if (nrOfRows != 0) { + // iterator().next() is like iterator().one() => return null on end + // https://docs.datastax.com/en/drivers/java/2.0/com/datastax/driver/core/ResultSet.html#one-- + if(row == null){ + break; + } + + if (nrOfRows != 0) { + outStream.write(",".getBytes(charset)); + } + + outStream.write("{".getBytes(charset)); + for (int i = 0; i < columnDefinitions.size(); i++) { + final DataType dataType = columnDefinitions.getType(i); + final String colName = columnDefinitions.getName(i); + if (i != 0) { outStream.write(",".getBytes(charset)); } - outStream.write("{".getBytes(charset)); - for (int i = 0; i < columnDefinitions.size(); i++) { - final DataType dataType = columnDefinitions.getType(i); - final String colName = columnDefinitions.getName(i); - if (i != 0) { - outStream.write(",".getBytes(charset)); - } - if (row.isNull(i)) { - outStream.write(("\"" + colName + "\"" + ":null").getBytes(charset)); - } else { - Object value = getCassandraObject(row, i, dataType); - String valueString; - if (value instanceof List || value instanceof Set) { - boolean first = true; - StringBuilder sb = new StringBuilder("["); - for (Object element : ((Collection) value)) { - if (!first) { - sb.append(","); - } - sb.append(getJsonElement(context, element)); - first = false; + if (row.isNull(i)) { + outStream.write(("\"" + colName + "\"" + ":null").getBytes(charset)); + } else { + Object value = getCassandraObject(row, i, dataType); + String valueString; + if (value instanceof List || value instanceof Set) { + boolean first = true; + StringBuilder sb = new StringBuilder("["); + for (Object element : ((Collection) value)) { + if (!first) { + sb.append(","); } - sb.append("]"); - valueString = sb.toString(); - } else if (value instanceof Map) { - boolean first = true; - StringBuilder sb = new StringBuilder("{"); - for (Object element : ((Map) value).entrySet()) { - Map.Entry entry = (Map.Entry) element; - Object mapKey = entry.getKey(); - Object mapValue = entry.getValue(); - - if (!first) { - sb.append(","); - } - sb.append(getJsonElement(context, mapKey)); - sb.append(":"); - sb.append(getJsonElement(context, mapValue)); - first = false; - } - sb.append("}"); - valueString = sb.toString(); - } else { - valueString = getJsonElement(context, value); + sb.append(getJsonElement(context, element)); + first = false; } - outStream.write(("\"" + colName + "\":" - + valueString + "").getBytes(charset)); + sb.append("]"); + valueString = sb.toString(); + } else if (value instanceof Map) { + boolean first = true; + StringBuilder sb = new StringBuilder("{"); + for (Object element : ((Map) value).entrySet()) { + Map.Entry entry = (Map.Entry) element; + Object mapKey = entry.getKey(); + Object mapValue = entry.getValue(); + + if (!first) { + sb.append(","); + } + sb.append(getJsonElement(context, mapKey)); + sb.append(":"); + sb.append(getJsonElement(context, mapValue)); + first = false; + } + sb.append("}"); + valueString = sb.toString(); + } else { + valueString = getJsonElement(context, value); } + outStream.write(("\"" + colName + "\":" + + valueString + "").getBytes(charset)); } - nrOfRows += 1; - outStream.write("}".getBytes(charset)); } - } while (!rs.isFullyFetched()); + nrOfRows += 1; + outStream.write("}".getBytes(charset)); + } } return nrOfRows; } finally { @@ -511,15 +658,6 @@ public class QueryCassandra extends AbstractCassandraProcessor { } } - private static String getFormattedDate(final Optional context, Date value) { - final String dateFormatPattern = context - .map(_context -> _context.getProperty(TIMESTAMP_FORMAT_PATTERN).getValue()) - .orElse(TIMESTAMP_FORMAT_PATTERN.getDefaultValue()); - SimpleDateFormat dateFormat = new SimpleDateFormat(dateFormatPattern); - dateFormat.setTimeZone(TimeZone.getTimeZone("UTC")); - return dateFormat.format(value); - } - /** * Creates an Avro schema from the given result set. The metadata (column definitions, data types, etc.) is used * to determine a schema for Avro. @@ -577,4 +715,4 @@ public class QueryCassandra extends AbstractCassandraProcessor { } return builder.endRecord(); } -} +} \ No newline at end of file diff --git a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/CassandraQueryTestUtil.java b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/CassandraQueryTestUtil.java index d5e5a087c3..e31627aa0b 100644 --- a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/CassandraQueryTestUtil.java +++ b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/CassandraQueryTestUtil.java @@ -22,19 +22,20 @@ import com.datastax.driver.core.ResultSet; import com.datastax.driver.core.Row; import com.google.common.collect.Sets; import com.google.common.reflect.TypeToken; +import com.google.common.util.concurrent.ListenableFuture; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import java.text.SimpleDateFormat; +import java.util.Set; +import java.util.Map; +import java.util.List; +import java.util.HashMap; import java.util.Arrays; import java.util.Calendar; import java.util.Collections; import java.util.Date; import java.util.GregorianCalendar; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; import java.util.TimeZone; import static org.mockito.ArgumentMatchers.any; @@ -56,7 +57,7 @@ public class CassandraQueryTestUtil { TEST_DATE = c.getTime(); } - public static ResultSet createMockResultSet() throws Exception { + public static ResultSet createMockResultSet(boolean falseThenTrue) throws Exception { ResultSet resultSet = mock(ResultSet.class); ColumnDefinitions columnDefinitions = mock(ColumnDefinitions.class); when(columnDefinitions.size()).thenReturn(9); @@ -106,14 +107,28 @@ public class CassandraQueryTestUtil { }}, true, 3.0f, 4.0) ); + ListenableFuture future = mock(ListenableFuture.class); + when(future.get()).thenReturn(rows); + when(resultSet.fetchMoreResults()).thenReturn(future); + when(resultSet.iterator()).thenReturn(rows.iterator()); when(resultSet.all()).thenReturn(rows); when(resultSet.getAvailableWithoutFetching()).thenReturn(rows.size()); when(resultSet.isFullyFetched()).thenReturn(false).thenReturn(true); + if(falseThenTrue) { + when(resultSet.isExhausted()).thenReturn(false, true); + }else{ + when(resultSet.isExhausted()).thenReturn(true); + } when(resultSet.getColumnDefinitions()).thenReturn(columnDefinitions); + return resultSet; } + public static ResultSet createMockResultSet() throws Exception { + return createMockResultSet(true); + } + public static ResultSet createMockResultSetOneColumn() throws Exception { ResultSet resultSet = mock(ResultSet.class); ColumnDefinitions columnDefinitions = mock(ColumnDefinitions.class); @@ -143,10 +158,15 @@ public class CassandraQueryTestUtil { createRow("user2") ); + ListenableFuture future = mock(ListenableFuture.class); + when(future.get()).thenReturn(rows); + when(resultSet.fetchMoreResults()).thenReturn(future); + when(resultSet.iterator()).thenReturn(rows.iterator()); when(resultSet.all()).thenReturn(rows); when(resultSet.getAvailableWithoutFetching()).thenReturn(rows.size()); when(resultSet.isFullyFetched()).thenReturn(false).thenReturn(true); + when(resultSet.isExhausted()).thenReturn(false).thenReturn(true); when(resultSet.getColumnDefinitions()).thenReturn(columnDefinitions); return resultSet; } @@ -195,3 +215,4 @@ public class CassandraQueryTestUtil { return row; } } + diff --git a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraIT.java b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraIT.java new file mode 100644 index 0000000000..2bfd21f6b6 --- /dev/null +++ b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraIT.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.cassandra; + +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.ResultSet; +import com.datastax.driver.core.Session; +import com.datastax.driver.core.querybuilder.QueryBuilder; +import com.datastax.driver.core.querybuilder.Select; +import com.datastax.driver.core.querybuilder.Truncate; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.serialization.record.MockRecordParser; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.net.InetSocketAddress; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Testcontainers +public class QueryCassandraIT { + @Container + private static final CassandraContainer CASSANDRA_CONTAINER = new CassandraContainer(DockerImageName.parse("cassandra:4.1")); + + private static TestRunner putCassandraTestRunner; + private static TestRunner queryCassandraTestRunner; + private static MockRecordParser recordReader; + + private static Cluster cluster; + private static Session session; + + private static final int LOAD_FLOW_FILE_SIZE = 100; + private static final int LOAD_FLOW_FILE_BATCH_SIZE = 10; + + private static final String KEYSPACE = "sample_keyspace"; + private static final String TABLE = "sample_table"; + + @BeforeAll + public static void setup() throws InitializationException { + recordReader = new MockRecordParser(); + putCassandraTestRunner = TestRunners.newTestRunner(PutCassandraRecord.class); + queryCassandraTestRunner = TestRunners.newTestRunner(QueryCassandra.class); + + InetSocketAddress contactPoint = CASSANDRA_CONTAINER.getContactPoint(); + putCassandraTestRunner.setProperty(PutCassandraRecord.RECORD_READER_FACTORY, "reader"); + putCassandraTestRunner.setProperty(PutCassandraRecord.CONTACT_POINTS, contactPoint.getHostString() + ":" + contactPoint.getPort()); + putCassandraTestRunner.setProperty(PutCassandraRecord.KEYSPACE, KEYSPACE); + putCassandraTestRunner.setProperty(PutCassandraRecord.TABLE, TABLE); + putCassandraTestRunner.setProperty(PutCassandraRecord.CONSISTENCY_LEVEL, "SERIAL"); + putCassandraTestRunner.setProperty(PutCassandraRecord.BATCH_STATEMENT_TYPE, "LOGGED"); + putCassandraTestRunner.addControllerService("reader", recordReader); + putCassandraTestRunner.enableControllerService(recordReader); + + queryCassandraTestRunner.setProperty(QueryCassandra.CONTACT_POINTS, contactPoint.getHostName() + ":" + contactPoint.getPort()); + queryCassandraTestRunner.setProperty(QueryCassandra.FETCH_SIZE, "10"); + queryCassandraTestRunner.setProperty(QueryCassandra.OUTPUT_BATCH_SIZE, "10"); + queryCassandraTestRunner.setProperty(QueryCassandra.KEYSPACE, KEYSPACE); + queryCassandraTestRunner.setProperty(QueryCassandra.CQL_SELECT_QUERY, "select * from " + TABLE + ";"); + + cluster = Cluster.builder().addContactPoint(contactPoint.getHostName()) + .withPort(contactPoint.getPort()).build(); + session = cluster.connect(); + + String createKeyspace = "CREATE KEYSPACE IF NOT EXISTS " + KEYSPACE + " WITH replication = {'class':'SimpleStrategy','replication_factor':1};"; + String createTable = "CREATE TABLE IF NOT EXISTS " + KEYSPACE + "." + TABLE + "(id int PRIMARY KEY, uuid text, age int);"; + + session.execute(createKeyspace); + session.execute(createTable); + loadData(); + } + + private static void loadData() { + recordReader.addSchemaField("id", RecordFieldType.INT); + recordReader.addSchemaField("uuid", RecordFieldType.STRING); + recordReader.addSchemaField("age", RecordFieldType.INT); + int recordCount = 0; + + for (int i = 0; i resultsList = result.all() + .stream() + .map(r -> r.getInt(0)) + .collect(Collectors.toList()); + + return resultsList.size(); + } + + private void dropRecords() { + Truncate query = QueryBuilder.truncate(KEYSPACE, TABLE); + session.execute(query); + } + + @AfterAll + public static void shutdown() { + String dropKeyspace = "DROP KEYSPACE " + KEYSPACE; + String dropTable = "DROP TABLE IF EXISTS " + KEYSPACE + "." + TABLE; + + session.execute(dropTable); + session.execute(dropKeyspace); + + session.close(); + cluster.close(); + } +} diff --git a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraTest.java b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraTest.java index 9e329b0f7b..36cdc3af67 100644 --- a/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraTest.java +++ b/nifi-nar-bundles/nifi-cassandra-bundle/nifi-cassandra-processors/src/test/java/org/apache/nifi/processors/cassandra/QueryCassandraTest.java @@ -165,6 +165,7 @@ public class QueryCassandraTest { @Test public void testProcessorELConfigJsonOutput() { + setUpStandardProcessorConfig(); testRunner.setProperty(AbstractCassandraProcessor.CONTACT_POINTS, "${hosts}"); testRunner.setProperty(QueryCassandra.CQL_SELECT_QUERY, "${query}"); testRunner.setProperty(AbstractCassandraProcessor.PASSWORD, "${pass}"); @@ -172,6 +173,7 @@ public class QueryCassandraTest { testRunner.setProperty(AbstractCassandraProcessor.CHARSET, "${charset}"); testRunner.setProperty(QueryCassandra.QUERY_TIMEOUT, "${timeout}"); testRunner.setProperty(QueryCassandra.FETCH_SIZE, "${fetch}"); + testRunner.setProperty(QueryCassandra.MAX_ROWS_PER_FLOW_FILE, "${max-rows-per-flow}"); testRunner.setIncomingConnection(false); testRunner.assertValid(); @@ -181,6 +183,7 @@ public class QueryCassandraTest { testRunner.setVariable("charset", "UTF-8"); testRunner.setVariable("timeout", "30 sec"); testRunner.setVariable("fetch", "0"); + testRunner.setVariable("max-rows-per-flow", "0"); // Test JSON output testRunner.setProperty(QueryCassandra.OUTPUT_FORMAT, QueryCassandra.JSON_FORMAT); @@ -216,7 +219,7 @@ public class QueryCassandraTest { } @Test - public void testProcessorEmptyFlowFileAndExceptions() { + public void testProcessorEmptyFlowFile() { setUpStandardProcessorConfig(); // Run with empty flowfile @@ -224,36 +227,76 @@ public class QueryCassandraTest { processor.setExceptionToThrow(null); testRunner.enqueue("".getBytes()); testRunner.run(1, true, true); - testRunner.assertAllFlowFilesTransferred(QueryCassandra.REL_SUCCESS, 1); + testRunner.assertTransferCount(QueryCassandra.REL_SUCCESS, 1); testRunner.clearTransferState(); + } + + @Test + public void testProcessorEmptyFlowFileMaxRowsPerFlowFileEqOne() { + + processor = new MockQueryCassandraTwoRounds(); + testRunner = TestRunners.newTestRunner(processor); + + setUpStandardProcessorConfig(); + + testRunner.setIncomingConnection(true); + testRunner.setProperty(QueryCassandra.MAX_ROWS_PER_FLOW_FILE, "1"); + processor.setExceptionToThrow(null); + testRunner.enqueue("".getBytes()); + testRunner.run(1, true, true); + testRunner.assertTransferCount(QueryCassandra.REL_SUCCESS, 2); + testRunner.clearTransferState(); + } + + + @Test + public void testProcessorEmptyFlowFileAndNoHostAvailableException() { + setUpStandardProcessorConfig(); // Test exceptions processor.setExceptionToThrow(new NoHostAvailableException(new HashMap())); testRunner.enqueue("".getBytes()); testRunner.run(1, true, true); - testRunner.assertAllFlowFilesTransferred(QueryCassandra.REL_RETRY, 1); + testRunner.assertTransferCount(QueryCassandra.REL_RETRY, 1); testRunner.clearTransferState(); + } + + @Test + public void testProcessorEmptyFlowFileAndInetSocketAddressConsistencyLevelANY() { + setUpStandardProcessorConfig(); processor.setExceptionToThrow( new ReadTimeoutException(new SniEndPoint(new InetSocketAddress("localhost", 9042), ""), ConsistencyLevel.ANY, 0, 1, false)); testRunner.enqueue("".getBytes()); testRunner.run(1, true, true); - testRunner.assertAllFlowFilesTransferred(QueryCassandra.REL_RETRY, 1); + testRunner.assertTransferCount(QueryCassandra.REL_RETRY, 1); testRunner.clearTransferState(); + } + + @Test + public void testProcessorEmptyFlowFileAndInetSocketAddressDefault() { + setUpStandardProcessorConfig(); processor.setExceptionToThrow( new InvalidQueryException(new SniEndPoint(new InetSocketAddress("localhost", 9042), ""), "invalid query")); testRunner.enqueue("".getBytes()); testRunner.run(1, true, true); - testRunner.assertAllFlowFilesTransferred(QueryCassandra.REL_FAILURE, 1); + testRunner.assertTransferCount(QueryCassandra.REL_FAILURE, 1); testRunner.clearTransferState(); + } + + @Test + public void testProcessorEmptyFlowFileAndExceptionsProcessException() { + setUpStandardProcessorConfig(); processor.setExceptionToThrow(new ProcessException()); testRunner.enqueue("".getBytes()); testRunner.run(1, true, true); - testRunner.assertAllFlowFilesTransferred(QueryCassandra.REL_FAILURE, 1); + testRunner.assertTransferCount(QueryCassandra.REL_FAILURE, 1); } + // -- + @Test public void testCreateSchemaOneColumn() throws Exception { ResultSet rs = CassandraQueryTestUtil.createMockResultSetOneColumn(); @@ -264,7 +307,7 @@ public class QueryCassandraTest { @Test public void testCreateSchema() throws Exception { - ResultSet rs = CassandraQueryTestUtil.createMockResultSet(); + ResultSet rs = CassandraQueryTestUtil.createMockResultSet(true); Schema schema = QueryCassandra.createSchema(rs); assertNotNull(schema); assertEquals(Schema.Type.RECORD, schema.getType()); @@ -369,17 +412,20 @@ public class QueryCassandraTest { @Test public void testConvertToAvroStream() throws Exception { + setUpStandardProcessorConfig(); ResultSet rs = CassandraQueryTestUtil.createMockResultSet(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); - long numberOfRows = QueryCassandra.convertToAvroStream(rs, baos, 0, null); + long numberOfRows = QueryCassandra.convertToAvroStream(rs, 0, baos, 0, null); assertEquals(2, numberOfRows); } @Test public void testConvertToJSONStream() throws Exception { + setUpStandardProcessorConfig(); ResultSet rs = CassandraQueryTestUtil.createMockResultSet(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); - long numberOfRows = QueryCassandra.convertToJsonStream(rs, baos, StandardCharsets.UTF_8, 0, null); + long numberOfRows = QueryCassandra.convertToJsonStream(rs, 0, baos, StandardCharsets.UTF_8, + 0, null); assertEquals(2, numberOfRows); } @@ -391,7 +437,7 @@ public class QueryCassandraTest { DateFormat df = new SimpleDateFormat(QueryCassandra.TIMESTAMP_FORMAT_PATTERN.getDefaultValue()); df.setTimeZone(TimeZone.getTimeZone("UTC")); - long numberOfRows = QueryCassandra.convertToJsonStream(Optional.of(testRunner.getProcessContext()), rs, baos, + long numberOfRows = QueryCassandra.convertToJsonStream(Optional.of(testRunner.getProcessContext()), rs, 0, baos, StandardCharsets.UTF_8, 0, null); assertEquals(1, numberOfRows); @@ -411,7 +457,7 @@ public class QueryCassandraTest { DateFormat df = new SimpleDateFormat(customDateFormat); df.setTimeZone(TimeZone.getTimeZone("UTC")); - long numberOfRows = QueryCassandra.convertToJsonStream(Optional.of(context), rs, baos, StandardCharsets.UTF_8, 0, null); + long numberOfRows = QueryCassandra.convertToJsonStream(Optional.of(context), rs, 0, baos, StandardCharsets.UTF_8, 0, null); assertEquals(1, numberOfRows); Map>> map = new ObjectMapper().readValue(baos.toByteArray(), HashMap.class); @@ -425,6 +471,7 @@ public class QueryCassandraTest { testRunner.setProperty(QueryCassandra.CQL_SELECT_QUERY, "select * from test"); testRunner.setProperty(AbstractCassandraProcessor.PASSWORD, "password"); testRunner.setProperty(AbstractCassandraProcessor.USERNAME, "username"); + testRunner.setProperty(QueryCassandra.MAX_ROWS_PER_FLOW_FILE, "0"); } /** @@ -448,17 +495,21 @@ public class QueryCassandraTest { Configuration config = Configuration.builder().build(); when(mockCluster.getConfiguration()).thenReturn(config); ResultSetFuture future = mock(ResultSetFuture.class); - ResultSet rs = CassandraQueryTestUtil.createMockResultSet(); + ResultSet rs = CassandraQueryTestUtil.createMockResultSet(false); when(future.getUninterruptibly()).thenReturn(rs); + try { doReturn(rs).when(future).getUninterruptibly(anyLong(), any(TimeUnit.class)); } catch (TimeoutException te) { throw new IllegalArgumentException("Mocked cluster doesn't time out"); } + if (exceptionToThrow != null) { - when(mockSession.executeAsync(anyString())).thenThrow(exceptionToThrow); + when(mockSession.execute(anyString(), any(), any())).thenThrow(exceptionToThrow); + when(mockSession.execute(anyString())).thenThrow(exceptionToThrow); } else { - when(mockSession.executeAsync(anyString())).thenReturn(future); + when(mockSession.execute(anyString(),any(), any())).thenReturn(rs); + when(mockSession.execute(anyString())).thenReturn(rs); } } catch (Exception e) { fail(e.getMessage()); @@ -469,7 +520,52 @@ public class QueryCassandraTest { public void setExceptionToThrow(Exception e) { this.exceptionToThrow = e; } + } + private static class MockQueryCassandraTwoRounds extends MockQueryCassandra { + + private Exception exceptionToThrow = null; + + @Override + protected Cluster createCluster(List contactPoints, SSLContext sslContext, + String username, String password, String compressionType) { + Cluster mockCluster = mock(Cluster.class); + try { + Metadata mockMetadata = mock(Metadata.class); + when(mockMetadata.getClusterName()).thenReturn("cluster1"); + when(mockCluster.getMetadata()).thenReturn(mockMetadata); + Session mockSession = mock(Session.class); + when(mockCluster.connect()).thenReturn(mockSession); + when(mockCluster.connect(anyString())).thenReturn(mockSession); + Configuration config = Configuration.builder().build(); + when(mockCluster.getConfiguration()).thenReturn(config); + ResultSetFuture future = mock(ResultSetFuture.class); + ResultSet rs = CassandraQueryTestUtil.createMockResultSet(true); + when(future.getUninterruptibly()).thenReturn(rs); + + try { + doReturn(rs).when(future).getUninterruptibly(anyLong(), any(TimeUnit.class)); + } catch (TimeoutException te) { + throw new IllegalArgumentException("Mocked cluster doesn't time out"); + } + + if (exceptionToThrow != null) { + when(mockSession.execute(anyString(), any(), any())).thenThrow(exceptionToThrow); + when(mockSession.execute(anyString())).thenThrow(exceptionToThrow); + } else { + when(mockSession.execute(anyString(),any(), any())).thenReturn(rs); + when(mockSession.execute(anyString())).thenReturn(rs); + } + } catch (Exception e) { + fail(e.getMessage()); + } + return mockCluster; + } + + public void setExceptionToThrow(Exception e) { + this.exceptionToThrow = e; + } } } +