diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java index bd1c5c60e71..3f65bd8db12 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java @@ -18,14 +18,12 @@ package org.apache.solr.client.solrj.io.graph; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; +import java.lang.invoke.MethodHandles; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.*; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -51,6 +49,8 @@ import org.apache.solr.client.solrj.io.stream.metrics.Metric; import org.apache.solr.common.params.ModifiableSolrParams; import org.apache.solr.common.util.ExecutorUtil; import org.apache.solr.common.util.SolrNamedThreadFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static org.apache.solr.common.params.CommonParams.SORT; @@ -75,6 +75,11 @@ public class GatherNodesStream extends TupleStream implements Expressible { private Traversal traversal; private List metrics; private int maxDocFreq; + private SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ssX", Locale.ENGLISH); + private Set windowSet; + private int window = Integer.MIN_VALUE; + private int lag = 1; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); public GatherNodesStream(String zkHost, String collection, @@ -98,7 +103,9 @@ public class GatherNodesStream extends TupleStream implements Expressible { metrics, trackTraversal, scatter, - maxDocFreq); + maxDocFreq, + Integer.MIN_VALUE, + 1); } public GatherNodesStream(StreamExpression expression, StreamFactory factory) throws IOException { @@ -196,6 +203,20 @@ public class GatherNodesStream extends TupleStream implements Expressible { useDefaultTraversal = true; } + StreamExpressionNamedParameter windowExpression = factory.getNamedOperand(expression, "window"); + int timeWindow = Integer.MIN_VALUE; + + if(windowExpression != null) { + timeWindow = Integer.parseInt(((StreamExpressionValue) windowExpression.getParameter()).getValue()); + } + + StreamExpressionNamedParameter lagExpression = factory.getNamedOperand(expression, "lag"); + int timeLag = 1; + + if(lagExpression != null) { + timeLag = Integer.parseInt(((StreamExpressionValue) lagExpression.getParameter()).getValue()); + } + StreamExpressionNamedParameter docFreqExpression = factory.getNamedOperand(expression, "maxDocFreq"); int docFreq = -1; @@ -210,7 +231,10 @@ public class GatherNodesStream extends TupleStream implements Expressible { !namedParam.getName().equals("walk") && !namedParam.getName().equals("scatter") && !namedParam.getName().equals("maxDocFreq") && - !namedParam.getName().equals("trackTraversal")) + !namedParam.getName().equals("trackTraversal") && + !namedParam.getName().equals("window") && + !namedParam.getName().equals("lag") + ) { params.put(namedParam.getName(), namedParam.getParameter().toString().trim()); } @@ -242,7 +266,9 @@ public class GatherNodesStream extends TupleStream implements Expressible { metrics, trackTraversal, scatter, - docFreq); + docFreq, + timeWindow, + timeLag); } @SuppressWarnings({"unchecked"}) @@ -256,7 +282,9 @@ public class GatherNodesStream extends TupleStream implements Expressible { List metrics, boolean trackTraversal, Set scatter, - int maxDocFreq) { + int maxDocFreq, + int window, + int lag) { this.zkHost = zkHost; this.collection = collection; this.tupleStream = tupleStream; @@ -268,6 +296,13 @@ public class GatherNodesStream extends TupleStream implements Expressible { this.trackTraversal = trackTraversal; this.scatter = scatter; this.maxDocFreq = maxDocFreq; + this.window = window; + + if(window > Integer.MIN_VALUE) { + windowSet = new HashSet<>(); + } + + this.lag = lag; } @Override @@ -506,6 +541,26 @@ public class GatherNodesStream extends TupleStream implements Expressible { } + private String[] getTenSecondWindow(int size, int lag, String start) { + try { + String[] window = new String[size]; + Date date = this.dateFormat.parse(start); + Instant instant = date.toInstant(); + + for (int i = 0; i < size; i++) { + Instant windowInstant = instant.minus(10 * (i + lag), ChronoUnit.SECONDS); + String windowString = windowInstant.toString(); + windowString = windowString.substring(0, 18) + "0Z"; + window[i] = windowString; + } + + return window; + } catch(ParseException e) { + log.warn("Unparseable date:{}", String.valueOf(start)); + return new String[0]; + } + } + public void close() throws IOException { tupleStream.close(); } @@ -558,8 +613,33 @@ public class GatherNodesStream extends TupleStream implements Expressible { } } - joinBatch.add(value); - if (joinBatch.size() == 400) { + if(windowSet == null || (lag == 1 && !windowSet.contains(String.valueOf(value)))) { + joinBatch.add(value); + } + + if(window > Integer.MIN_VALUE && value != null) { + windowSet.add(value); + + /* + * A time window has been set. + * The join value is expected to be an ISO formatted time stamp. + * We derive the window and add it to the join values below. + */ + + String[] timeWindow = getTenSecondWindow(window, lag, value); + for(String windowString : timeWindow) { + if(!windowSet.contains(windowString)) { + /* + * Time windows can overlap, so make sure we don't query for the same timestamp more then once. + * This would cause duplicate results if overlapping windows are collected in different threads. + */ + joinBatch.add(windowString); + } + windowSet.add(windowString); + } + } + + if (joinBatch.size() >= 400) { JoinRunner joinRunner = new JoinRunner(joinBatch); @SuppressWarnings({"rawtypes"}) Future future = threadPool.submit(joinRunner); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java index 6b551d784d8..8e794bcbb33 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java @@ -30,7 +30,7 @@ public class Node { public Node(String id, boolean track) { this.id=id; if(track) { - ancestors = new HashSet<>(); + ancestors = new TreeSet<>(); } } diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java index f5f757db5f7..0a5777459e9 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java @@ -39,13 +39,7 @@ import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.ComparatorOrder; import org.apache.solr.client.solrj.io.comp.FieldComparator; -import org.apache.solr.client.solrj.io.stream.CloudSolrStream; -import org.apache.solr.client.solrj.io.stream.FacetStream; -import org.apache.solr.client.solrj.io.stream.HashJoinStream; -import org.apache.solr.client.solrj.io.stream.ScoreNodesStream; -import org.apache.solr.client.solrj.io.stream.SortStream; -import org.apache.solr.client.solrj.io.stream.StreamContext; -import org.apache.solr.client.solrj.io.stream.TupleStream; +import org.apache.solr.client.solrj.io.stream.*; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; import org.apache.solr.client.solrj.io.stream.metrics.CountMetric; import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric; @@ -250,18 +244,18 @@ public class GraphExpressionTest extends SolrCloudTestCase { public void testGatherNodesStream() throws Exception { new UpdateRequest() - .add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20") - .add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30") - .add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1") - .add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2") - .add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5") - .add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10") - .add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20") - .add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10") - .add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10") - .add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40") - .add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10") - .add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10") + .add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20", "time_ten_seconds_s", "2020-09-24T18:23:50Z") + .add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30", "time_ten_seconds_s", "2020-09-24T18:23:40Z") + .add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1", "time_ten_seconds_s", "2020-09-24T18:23:30Z") + .add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2", "time_ten_seconds_s", "2020-09-24T18:23:20Z") + .add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5", "time_ten_seconds_s", "2020-09-24T18:23:10Z") + .add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:23:00Z") + .add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20", "time_ten_seconds_s", "2020-09-24T18:22:50Z") + .add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:40Z") + .add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:30Z") + .add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40", "time_ten_seconds_s", "2020-09-24T18:22:20Z") + .add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:10Z") + .add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:00Z") .commit(cluster.getSolrClient(), COLLECTION); List tuples = null; @@ -276,10 +270,12 @@ public class GraphExpressionTest extends SolrCloudTestCase { .withFunctionName("gatherNodes", GatherNodesStream.class) .withFunctionName("nodes", GatherNodesStream.class) .withFunctionName("search", CloudSolrStream.class) + .withFunctionName("random", RandomStream.class) .withFunctionName("count", CountMetric.class) .withFunctionName("avg", MeanMetric.class) .withFunctionName("sum", SumMetric.class) .withFunctionName("min", MinMetric.class) + .withFunctionName("sort", SortStream.class) .withFunctionName("max", MaxMetric.class); String expr = "nodes(collection1, " + @@ -293,6 +289,7 @@ public class GraphExpressionTest extends SolrCloudTestCase { Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); assertTrue(tuples.size() == 4); + assertTrue(tuples.get(0).getString("node").equals("basket1")); assertTrue(tuples.get(1).getString("node").equals("basket2")); assertTrue(tuples.get(2).getString("node").equals("basket3")); @@ -388,8 +385,43 @@ public class GraphExpressionTest extends SolrCloudTestCase { assertTrue(tuples.get(0).getString("node").equals("basket2")); assertTrue(tuples.get(1).getString("node").equals("basket3")); - cache.close(); + //Test the window without lag + + expr = "nodes(collection1, random(collection1, q=\"id:(1 2)\", fl=\"time_ten_seconds_s\"), walk=\"time_ten_seconds_s->time_ten_seconds_s\", gather=\"id\", window=\"3\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 5); + assertTrue(tuples.get(0).getString("node").equals("1")); + assertTrue(tuples.get(1).getString("node").equals("2")); + assertTrue(tuples.get(2).getString("node").equals("3")); + assertTrue(tuples.get(3).getString("node").equals("4")); + assertTrue(tuples.get(4).getString("node").equals("5")); + + + //Test window with lag + + expr = "nodes(collection1, random(collection1, q=\"id:(1)\", fl=\"time_ten_seconds_s\"), walk=\"time_ten_seconds_s->time_ten_seconds_s\", gather=\"id\", window=\"2\", lag=\"2\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 2); + assertTrue(tuples.get(0).getString("node").equals("3")); + assertTrue(tuples.get(1).getString("node").equals("4")); + cache.close(); }