SOLR-15132: Add temporal graph query to the nodes Streaming Expression

Joel Bernstein 2021-02-12 15:18:54 -05:00
parent 4b113067d8
commit 4a42ecd936
3 changed files with 147 additions and 35 deletions
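Editor's note: this commit adds two optional named parameters, window and lag, to the nodes/gatherNodes streaming expression, so that a walk over an ISO-timestamp field can also gather the preceding ten-second buckets. Below is a minimal SolrJ sketch of the new syntax, modeled on the test added at the bottom of this commit; it is illustrative only, and the zkHost variable (and the surrounding method that would throw IOException) are assumed rather than part of the patch.

import java.io.IOException;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.graph.GatherNodesStream;
import org.apache.solr.client.solrj.io.stream.RandomStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

// Register the expression names used below; zkHost is assumed to point at the cluster's ZooKeeper.
StreamFactory factory = new StreamFactory()
    .withCollectionZkHost("collection1", zkHost)
    .withFunctionName("nodes", GatherNodesStream.class)
    .withFunctionName("random", RandomStream.class);

// Walk the timestamps of two seed documents and also gather the ids found in the
// three ten-second buckets preceding each timestamp (window="3", default lag of 1).
String expr = "nodes(collection1, " +
    "random(collection1, q=\"id:(1 2)\", fl=\"time_ten_seconds_s\"), " +
    "walk=\"time_ten_seconds_s->time_ten_seconds_s\", gather=\"id\", window=\"3\")";

TupleStream stream = factory.constructStream(expr);
StreamContext context = new StreamContext();
context.setSolrClientCache(new SolrClientCache());
stream.setStreamContext(context);
try {
  stream.open();
  for (Tuple tuple = stream.read(); !tuple.EOF; tuple = stream.read()) {
    System.out.println(tuple.getString("node")); // one tuple per gathered node id
  }
} finally {
  stream.close();
}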

GatherNodesStream.java

@@ -18,14 +18,12 @@
package org.apache.solr.client.solrj.io.graph;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.lang.invoke.MethodHandles;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
@@ -51,6 +49,8 @@ import org.apache.solr.client.solrj.io.stream.metrics.Metric;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.SolrNamedThreadFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.solr.common.params.CommonParams.SORT;
@@ -75,6 +75,11 @@ public class GatherNodesStream extends TupleStream implements Expressible {
private Traversal traversal;
private List<Metric> metrics;
private int maxDocFreq;
private SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ssX", Locale.ENGLISH);
private Set<String> windowSet;
private int window = Integer.MIN_VALUE;
private int lag = 1;
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
public GatherNodesStream(String zkHost,
String collection,
@@ -98,7 +103,9 @@ public class GatherNodesStream extends TupleStream implements Expressible {
metrics,
trackTraversal,
scatter,
maxDocFreq);
maxDocFreq,
Integer.MIN_VALUE,
1);
}
public GatherNodesStream(StreamExpression expression, StreamFactory factory) throws IOException {
@@ -196,6 +203,20 @@ public class GatherNodesStream extends TupleStream implements Expressible {
useDefaultTraversal = true;
}
StreamExpressionNamedParameter windowExpression = factory.getNamedOperand(expression, "window");
int timeWindow = Integer.MIN_VALUE;
if(windowExpression != null) {
timeWindow = Integer.parseInt(((StreamExpressionValue) windowExpression.getParameter()).getValue());
}
StreamExpressionNamedParameter lagExpression = factory.getNamedOperand(expression, "lag");
int timeLag = 1;
if(lagExpression != null) {
timeLag = Integer.parseInt(((StreamExpressionValue) lagExpression.getParameter()).getValue());
}
StreamExpressionNamedParameter docFreqExpression = factory.getNamedOperand(expression, "maxDocFreq");
int docFreq = -1;
@@ -210,7 +231,10 @@ public class GatherNodesStream extends TupleStream implements Expressible {
!namedParam.getName().equals("walk") &&
!namedParam.getName().equals("scatter") &&
!namedParam.getName().equals("maxDocFreq") &&
!namedParam.getName().equals("trackTraversal"))
!namedParam.getName().equals("trackTraversal") &&
!namedParam.getName().equals("window") &&
!namedParam.getName().equals("lag")
)
{
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
}
@@ -242,7 +266,9 @@ public class GatherNodesStream extends TupleStream implements Expressible {
metrics,
trackTraversal,
scatter,
docFreq);
docFreq,
timeWindow,
timeLag);
}
@SuppressWarnings({"unchecked"})
@@ -256,7 +282,9 @@ public class GatherNodesStream extends TupleStream implements Expressible {
List<Metric> metrics,
boolean trackTraversal,
Set<Traversal.Scatter> scatter,
int maxDocFreq) {
int maxDocFreq,
int window,
int lag) {
this.zkHost = zkHost;
this.collection = collection;
this.tupleStream = tupleStream;
@@ -268,6 +296,13 @@ public class GatherNodesStream extends TupleStream implements Expressible {
this.trackTraversal = trackTraversal;
this.scatter = scatter;
this.maxDocFreq = maxDocFreq;
this.window = window;
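// Integer.MIN_VALUE is the sentinel for "no time window" (the default of both the field and the
// parsed parameter), so the de-duplication set is only allocated when a window was requested.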
if(window > Integer.MIN_VALUE) {
windowSet = new HashSet<>();
}
this.lag = lag;
}
@Override
@@ -506,6 +541,26 @@ public class GatherNodesStream extends TupleStream implements Expressible {
}
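/*
 * Illustrative note (not part of the patch): with size=3, lag=1 and start="2020-09-24T18:23:50Z",
 * the method below returns {"2020-09-24T18:23:40Z", "2020-09-24T18:23:30Z", "2020-09-24T18:23:20Z"},
 * i.e. the three ten-second buckets immediately preceding the start timestamp, truncated to
 * ten-second resolution by the substring call.
 */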
private String[] getTenSecondWindow(int size, int lag, String start) {
try {
String[] window = new String[size];
Date date = this.dateFormat.parse(start);
Instant instant = date.toInstant();
for (int i = 0; i < size; i++) {
Instant windowInstant = instant.minus(10 * (i + lag), ChronoUnit.SECONDS);
String windowString = windowInstant.toString();
windowString = windowString.substring(0, 18) + "0Z";
window[i] = windowString;
}
return window;
} catch(ParseException e) {
log.warn("Unparseable date:{}", String.valueOf(start));
return new String[0];
}
}
public void close() throws IOException {
tupleStream.close();
}
@@ -558,8 +613,33 @@ public class GatherNodesStream extends TupleStream implements Expressible {
}
}
if(windowSet == null || (lag == 1 && !windowSet.contains(String.valueOf(value)))) {
joinBatch.add(value);
if (joinBatch.size() == 400) {
}
if(window > Integer.MIN_VALUE && value != null) {
windowSet.add(value);
/*
* A time window has been set.
* The join value is expected to be an ISO-formatted timestamp.
* We derive the window and add it to the join values below.
*/
String[] timeWindow = getTenSecondWindow(window, lag, value);
for(String windowString : timeWindow) {
if(!windowSet.contains(windowString)) {
/*
* Time windows can overlap, so make sure we don't query for the same timestamp more than once.
* This would cause duplicate results if overlapping windows are collected in different threads.
*/
joinBatch.add(windowString);
}
windowSet.add(windowString);
}
}
if (joinBatch.size() >= 400) {
JoinRunner joinRunner = new JoinRunner(joinBatch);
@SuppressWarnings({"rawtypes"})
Future future = threadPool.submit(joinRunner);

Node.java

@@ -30,7 +30,7 @@ public class Node {
public Node(String id, boolean track) {
this.id=id;
if(track) {
ancestors = new HashSet<>();
ancestors = new TreeSet<>();
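// TreeSet keeps the ancestor ids in sorted order; the HashSet it replaces made no iteration-order guarantee.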
}
}

GraphExpressionTest.java

@@ -39,13 +39,7 @@ import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
import org.apache.solr.client.solrj.io.comp.FieldComparator;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.FacetStream;
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
import org.apache.solr.client.solrj.io.stream.SortStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.*;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.io.stream.metrics.CountMetric;
import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric;
@@ -250,18 +244,18 @@ public class GraphExpressionTest extends SolrCloudTestCase {
public void testGatherNodesStream() throws Exception {
new UpdateRequest()
.add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20")
.add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30")
.add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1")
.add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2")
.add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5")
.add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10")
.add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20")
.add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10")
.add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10")
.add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40")
.add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10")
.add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10")
.add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20", "time_ten_seconds_s", "2020-09-24T18:23:50Z")
.add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30", "time_ten_seconds_s", "2020-09-24T18:23:40Z")
.add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1", "time_ten_seconds_s", "2020-09-24T18:23:30Z")
.add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2", "time_ten_seconds_s", "2020-09-24T18:23:20Z")
.add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5", "time_ten_seconds_s", "2020-09-24T18:23:10Z")
.add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:23:00Z")
.add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20", "time_ten_seconds_s", "2020-09-24T18:22:50Z")
.add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:40Z")
.add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:30Z")
.add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40", "time_ten_seconds_s", "2020-09-24T18:22:20Z")
.add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:10Z")
.add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10", "time_ten_seconds_s", "2020-09-24T18:22:00Z")
.commit(cluster.getSolrClient(), COLLECTION);
List<Tuple> tuples = null;
@@ -276,10 +270,12 @@ public class GraphExpressionTest extends SolrCloudTestCase {
.withFunctionName("gatherNodes", GatherNodesStream.class)
.withFunctionName("nodes", GatherNodesStream.class)
.withFunctionName("search", CloudSolrStream.class)
.withFunctionName("random", RandomStream.class)
.withFunctionName("count", CountMetric.class)
.withFunctionName("avg", MeanMetric.class)
.withFunctionName("sum", SumMetric.class)
.withFunctionName("min", MinMetric.class)
.withFunctionName("sort", SortStream.class)
.withFunctionName("max", MaxMetric.class);
String expr = "nodes(collection1, " +
@@ -293,6 +289,7 @@ public class GraphExpressionTest extends SolrCloudTestCase {
Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
assertTrue(tuples.size() == 4);
assertTrue(tuples.get(0).getString("node").equals("basket1"));
assertTrue(tuples.get(1).getString("node").equals("basket2"));
assertTrue(tuples.get(2).getString("node").equals("basket3"));
@@ -388,8 +385,43 @@ public class GraphExpressionTest extends SolrCloudTestCase {
assertTrue(tuples.get(0).getString("node").equals("basket2"));
assertTrue(tuples.get(1).getString("node").equals("basket3"));
cache.close();
//Test the window without lag
expr = "nodes(collection1, random(collection1, q=\"id:(1 2)\", fl=\"time_ten_seconds_s\"), walk=\"time_ten_seconds_s->time_ten_seconds_s\", gather=\"id\", window=\"3\")";
stream = (GatherNodesStream)factory.constructStream(expr);
context = new StreamContext();
context.setSolrClientCache(cache);
stream.setStreamContext(context);
tuples = getTuples(stream);
Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
assertTrue(tuples.size() == 5);
assertTrue(tuples.get(0).getString("node").equals("1"));
assertTrue(tuples.get(1).getString("node").equals("2"));
assertTrue(tuples.get(2).getString("node").equals("3"));
assertTrue(tuples.get(3).getString("node").equals("4"));
assertTrue(tuples.get(4).getString("node").equals("5"));
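/*
 * Illustrative walk-through (not part of the patch): docs 1 and 2 carry the timestamps
 * 2020-09-24T18:23:40Z and 18:23:30Z. With window="3" and the default lag of 1, each timestamp
 * contributes itself plus the three preceding ten-second buckets, so the union of join values
 * spans 18:23:00Z through 18:23:40Z, which matches docs 1-5 and gathers ids 1 through 5.
 */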
//Test window with lag
expr = "nodes(collection1, random(collection1, q=\"id:(1)\", fl=\"time_ten_seconds_s\"), walk=\"time_ten_seconds_s->time_ten_seconds_s\", gather=\"id\", window=\"2\", lag=\"2\")";
stream = (GatherNodesStream)factory.constructStream(expr);
context = new StreamContext();
context.setSolrClientCache(cache);
stream.setStreamContext(context);
tuples = getTuples(stream);
Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING));
assertTrue(tuples.size() == 2);
assertTrue(tuples.get(0).getString("node").equals("3"));
assertTrue(tuples.get(1).getString("node").equals("4"));
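/*
 * Illustrative walk-through (not part of the patch): doc 1 carries 2020-09-24T18:23:40Z. With
 * lag="2" the anchor timestamp itself is skipped and the two-bucket window starts two steps back,
 * yielding 18:23:20Z and 18:23:10Z, which match docs 3 and 4.
 */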
cache.close();
}