diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index c7bac9f6853..226058ee0c0 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -28,6 +28,7 @@ import java.util.Map.Entry;
 import org.apache.solr.client.solrj.io.SolrClientCache;
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
 import org.apache.solr.client.solrj.io.ops.ConcatOperation;
 import org.apache.solr.client.solrj.io.ops.DistinctOperation;
 import org.apache.solr.client.solrj.io.ops.GroupOperation;
@@ -115,6 +116,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
          .withFunctionName("complement", ComplementStream.class)
          .withFunctionName("daemon", DaemonStream.class)
          .withFunctionName("topic", TopicStream.class)
+         .withFunctionName("shortestPath", ShortestPathStream.class)
 
 
          // metrics
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/ShortestPathStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/ShortestPathStream.java
new file mode 100644
index 00000000000..bb9b09df3e8
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/ShortestPathStream.java
@@ -0,0 +1,543 @@
+package org.apache.solr.client.solrj.io.graph;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+import java.util.ArrayList;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor;
+import org.apache.solr.client.solrj.io.stream.*;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.comp.StreamComparator;
+import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.common.util.ExecutorUtil;
+import org.apache.solr.common.util.SolrjNamedThreadFactory;
+
+/**
+ * Streams the shortest paths between two nodes of a graph whose edges are documents in a
+ * SolrCloud collection.
+ *
+ * <p>Each document is one edge: its {@code fromField} value is the source node and its
+ * {@code toField} value is the target node. {@link #open()} runs a breadth-first search from
+ * {@code fromNode}, expanding one level per iteration with batched /export queries executed on a
+ * thread pool, and stops at the first level that reaches {@code toNode} or after
+ * {@code maxDepth} levels. Each emitted tuple holds one path, from {@code fromNode} to
+ * {@code toNode}, in its "path" field.
+ */
+public class ShortestPathStream extends TupleStream implements Expressible {
+
+  private static final long serialVersionUID = 1;
+
+  private String fromNode;
+  private String toNode;
+  private String fromField;
+  private String toField;
+  private int joinBatchSize;
+  private int maxDepth;
+  private String zkHost;
+  private String collection;
+  private LinkedList<Tuple> shortestPaths = new LinkedList<>();
+  private boolean found;
+  private StreamContext streamContext;
+  private int threads;
+  private Map<String, String> queryParams;
+
+  /**
+   * Creates a stream from explicit settings.
+   *
+   * @param zkHost ZooKeeper address of the SolrCloud cluster
+   * @param collection collection holding the edge documents
+   * @param fromNode node the search starts from
+   * @param toNode node the search is looking for
+   * @param fromField field holding the source node of an edge
+   * @param toField field holding the target node of an edge
+   * @param queryParams extra parameters added to every edge query (for example an fq)
+   * @param joinBatchSize number of frontier nodes placed in a single edge query
+   * @param threads size of the pool that runs the edge queries
+   * @param maxDepth maximum number of levels to expand
+   */
+  public ShortestPathStream(String zkHost,
+                            String collection,
+                            String fromNode,
+                            String toNode,
+                            String fromField,
+                            String toField,
+                            Map queryParams,
+                            int joinBatchSize,
+                            int threads,
+                            int maxDepth) {
+
+    init(zkHost,
+        collection,
+        fromNode,
+        toNode,
+        fromField,
+        toField,
+        queryParams,
+        joinBatchSize,
+        threads,
+        maxDepth);
+  }
+
+  /**
+   * Creates a stream from a {@code shortestPath(collection, from=..., to=..., edge="a=b",
+   * maxDepth=...)} expression. {@code threads} defaults to 6 and {@code partitionSize} to 250;
+   * named parameters that are not recognised are passed through to the edge queries.
+   *
+   * @throws IOException if a required parameter is missing or malformed, or no zkHost can be
+   *         resolved for the collection
+   */
+  public ShortestPathStream(StreamExpression expression, StreamFactory factory) throws IOException {
+
+    String collectionName = factory.getValueOperand(expression, 0);
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+    StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
+
+    // Collection Name
+    if(null == collectionName) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
+    }
+
+    String fromNode = null;
+    StreamExpressionNamedParameter fromExpression = factory.getNamedOperand(expression, "from");
+
+    if(fromExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - from param is required",expression));
+    } else {
+      fromNode = ((StreamExpressionValue)fromExpression.getParameter()).getValue();
+    }
+
+    String toNode = null;
+    StreamExpressionNamedParameter toExpression = factory.getNamedOperand(expression, "to");
+
+    if(toExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - to param is required", expression));
+    } else {
+      toNode = ((StreamExpressionValue)toExpression.getParameter()).getValue();
+    }
+
+    String fromField = null;
+    String toField = null;
+
+    StreamExpressionNamedParameter edgeExpression = factory.getNamedOperand(expression, "edge");
+
+    if(edgeExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - edge param is required", expression));
+    } else {
+      String edge = ((StreamExpressionValue)edgeExpression.getParameter()).getValue();
+      String[] fields = edge.split("=");
+      if(fields.length != 2) {
+        throw new IOException(String.format(Locale.ROOT,"invalid expression %s - edge param separated by and = and must contain two fields", expression));
+      }
+      fromField = fields[0].trim();
+      toField = fields[1].trim();
+    }
+
+    int threads = 6;
+
+    StreamExpressionNamedParameter threadsExpression = factory.getNamedOperand(expression, "threads");
+
+    if(threadsExpression != null) {
+      threads = Integer.parseInt(((StreamExpressionValue)threadsExpression.getParameter()).getValue());
+    }
+
+    int partitionSize = 250;
+
+    StreamExpressionNamedParameter partitionExpression = factory.getNamedOperand(expression, "partitionSize");
+
+    if(partitionExpression != null) {
+      partitionSize = Integer.parseInt(((StreamExpressionValue)partitionExpression.getParameter()).getValue());
+    }
+
+    int maxDepth = 0;
+
+    StreamExpressionNamedParameter depthExpression = factory.getNamedOperand(expression, "maxDepth");
+
+    if(depthExpression == null) {
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - maxDepth param is required", expression));
+    } else {
+      maxDepth = Integer.parseInt(((StreamExpressionValue) depthExpression.getParameter()).getValue());
+    }
+
+    Map<String, String> params = new HashMap<>();
+    for(StreamExpressionNamedParameter namedParam : namedParams){
+      if(!namedParam.getName().equals("zkHost") &&
+          !namedParam.getName().equals("to") &&
+          !namedParam.getName().equals("from") &&
+          !namedParam.getName().equals("edge") &&
+          !namedParam.getName().equals("maxDepth") &&
+          !namedParam.getName().equals("threads") &&
+          !namedParam.getName().equals("partitionSize"))
+      {
+        params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
+      }
+    }
+
+    // zkHost, optional - if not provided then will look into factory list to get
+    String zkHost = null;
+    if(null == zkHostExpression){
+      zkHost = factory.getCollectionZkHost(collectionName);
+      if(zkHost == null) {
+        zkHost = factory.getDefaultZkHost();
+      }
+    } else if(zkHostExpression.getParameter() instanceof StreamExpressionValue) {
+      zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
+    }
+
+    if(null == zkHost){
+      throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
+    }
+
+    // We've got all the required items
+    init(zkHost, collectionName, fromNode, toNode, fromField, toField, params, partitionSize, threads, maxDepth);
+  }
+
+  // The public constructor accepts a raw Map, so the assignment to the typed field is unchecked.
+  @SuppressWarnings("unchecked")
+  private void init(String zkHost,
+                    String collection,
+                    String fromNode,
+                    String toNode,
+                    String fromField,
+                    String toField,
+                    Map queryParams,
+                    int joinBatchSize,
+                    int threads,
+                    int maxDepth) {
+    this.zkHost = zkHost;
+    this.collection = collection;
+    this.fromNode = fromNode;
+    this.toNode = toNode;
+    this.fromField = fromField;
+    this.toField = toField;
+    this.queryParams = queryParams;
+    this.joinBatchSize = joinBatchSize;
+    this.threads = threads;
+    this.maxDepth = maxDepth;
+  }
+
+  /** Serializes this stream back into a streaming expression. */
+  @Override
+  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
+
+    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
+
+    // collection
+    expression.addParameter(collection);
+
+    Set<Map.Entry<String, String>> entries = queryParams.entrySet();
+    // parameters
+    for(Map.Entry<String, String> param : entries){
+      String value = param.getValue().toString();
+
+      // SOLR-8409: This is a special case where the params contain a " character
+      // Do note that in any other BASE streams with parameters where a " might come into play
+      // that this same replacement needs to take place.
+      value = value.replace("\"", "\\\"");
+
+      expression.addParameter(new StreamExpressionNamedParameter(param.getKey().toString(), value));
+    }
+
+    expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
+    expression.addParameter(new StreamExpressionNamedParameter("maxDepth", Integer.toString(maxDepth)));
+    expression.addParameter(new StreamExpressionNamedParameter("threads", Integer.toString(threads)));
+    expression.addParameter(new StreamExpressionNamedParameter("partitionSize", Integer.toString(joinBatchSize)));
+    expression.addParameter(new StreamExpressionNamedParameter("from", fromNode));
+    expression.addParameter(new StreamExpressionNamedParameter("to", toNode));
+    expression.addParameter(new StreamExpressionNamedParameter("edge", fromField+"="+toField));
+    return expression;
+  }
+
+  /** Stores the context that is handed to every edge-query stream. */
+  public void setStreamContext(StreamContext context) {
+    this.streamContext = context;
+  }
+
+  /** This stream wraps no other streams, so the list is always empty. */
+  public List<TupleStream> children() {
+    List<TupleStream> l = new ArrayList<>();
+    return l;
+  }
+
+  /**
+   * Runs the breadth-first search and buffers every shortest path found so that
+   * {@link #read()} can hand them out.
+   *
+   * @throws RuntimeException wrapping any failure raised while submitting or running the edge
+   *         queries
+   */
+  public void open() throws IOException {
+
+    List<Map<String, List<String>>> allVisited = new ArrayList<>();
+    Map<String, List<String>> visited = new HashMap<>();
+    visited.put(this.fromNode, null);
+
+    allVisited.add(visited);
+    int depth = 0;
+    Map<String, List<String>> nextVisited = null;
+    List<Edge> targets = new ArrayList<>();
+    ExecutorService threadPool = null;
+
+    try {
+
+      threadPool = ExecutorUtil.newMDCAwareFixedThreadPool(threads, new SolrjNamedThreadFactory("ShortestPathStream"));
+
+      //Breadth first search
+      while (targets.size() == 0 && depth < maxDepth) {
+        Set<String> nodes = visited.keySet();
+        Iterator<String> it = nodes.iterator();
+        nextVisited = new HashMap<>();
+        int batchCount = 0;
+        List<String> queryNodes = new ArrayList<>();
+        List<Future<List<Edge>>> futures = new ArrayList<>();
+        //Queue up all the batches
+        while (it.hasNext()) {
+          String node = it.next();
+          queryNodes.add(node);
+          ++batchCount;
+          if (batchCount == joinBatchSize || !it.hasNext()) {
+            try {
+              JoinRunner joinRunner = new JoinRunner(queryNodes);
+              Future<List<Edge>> future = threadPool.submit(joinRunner);
+              futures.add(future);
+            } catch (Exception e) {
+              throw new RuntimeException(e);
+            }
+            batchCount = 0;
+            queryNodes = new ArrayList<>();
+          }
+        }
+
+        try {
+          //Process the batches in submission order
+          for (Future<List<Edge>> future : futures) {
+            List<Edge> edges = future.get();
+            for (Edge edge : edges) {
+              if (toNode.equals(edge.to)) {
+                // The target is recorded even if seen before so every shortest path is kept.
+                targets.add(edge);
+                addParent(nextVisited, edge);
+              } else if (!cycle(edge.to, allVisited)) {
+                addParent(nextVisited, edge);
+              }
+            }
+          }
+        } catch (Exception e) {
+          if (e instanceof InterruptedException) {
+            // Keep the interrupt visible to callers further up the stack.
+            Thread.currentThread().interrupt();
+          }
+          throw new RuntimeException(e);
+        }
+
+        allVisited.add(nextVisited);
+        visited = nextVisited;
+        ++depth;
+      }
+    } finally {
+      // The pool is still null if its creation failed; guard so that failure is not masked by an NPE.
+      if (threadPool != null) {
+        threadPool.shutdown();
+      }
+    }
+
+    Set<String> finalPaths = new HashSet<>();
+    if(targets.size() > 0) {
+      for(Edge edge : targets) {
+        List<LinkedList<String>> paths = new ArrayList<>();
+        LinkedList<String> path = new LinkedList<>();
+        path.addFirst(edge.to);
+        paths.add(path);
+        //Walk back up the tree and collect the parent nodes.
+        for (int i = allVisited.size() - 1; i >= 0; --i) {
+          Map<String, List<String>> v = allVisited.get(i);
+          Iterator<LinkedList<String>> it = paths.iterator();
+          List<LinkedList<String>> newPaths = new ArrayList<>();
+          while(it.hasNext()) {
+            LinkedList<String> p = it.next();
+            List<String> parents = v.get(p.peekFirst());
+            if (parents != null) {
+              for(String parent : parents) {
+                LinkedList<String> newPath = new LinkedList<>();
+                newPath.addAll(p);
+                newPath.addFirst(parent);
+                newPaths.add(newPath);
+              }
+              paths = newPaths;
+            }
+          }
+        }
+
+        for(LinkedList<String> p : paths) {
+          String s = p.toString();
+          if (!finalPaths.contains(s)){
+            Tuple shortestPath = new Tuple(new HashMap());
+            shortestPath.put("path", p);
+            shortestPaths.add(shortestPath);
+            finalPaths.add(s);
+          }
+        }
+      }
+    }
+  }
+
+  /** Records {@code edge.from} as a parent of {@code edge.to} in the given level map. */
+  private void addParent(Map<String, List<String>> level, Edge edge) {
+    List<String> parents = level.get(edge.to);
+    if (parents == null) {
+      parents = new ArrayList<>();
+      level.put(edge.to, parents);
+    }
+    parents.add(edge.from);
+  }
+
+  /** Fetches the distinct edges that leave one batch of frontier nodes. */
+  private class JoinRunner implements Callable<List<Edge>> {
+
+    private List<String> nodes;
+    private List<Edge> edges = new ArrayList<>();
+
+    public JoinRunner(List<String> nodes) {
+      this.nodes = nodes;
+    }
+
+    public List<Edge> call() {
+
+      Map<String, String> joinParams = new HashMap<>();
+      String fl = fromField + "," + toField;
+
+      joinParams.putAll(queryParams);
+      joinParams.put("fl", fl);
+      joinParams.put("qt", "/export");
+      joinParams.put("sort", toField + " asc,"+fromField +" asc");
+
+      // NOTE(review): node values go into the query unescaped - confirm they cannot contain query syntax.
+      StringBuilder nodeQuery = new StringBuilder();
+
+      for(String node : nodes) {
+        nodeQuery.append(node).append(" ");
+      }
+
+      String q = fromField + ":(" + nodeQuery.toString().trim() + ")";
+
+      joinParams.put("q", q);
+      TupleStream stream = null;
+      try {
+        stream = new UniqueStream(new CloudSolrStream(zkHost, collection, joinParams), new MultipleFieldEqualitor(new FieldEqualitor(toField), new FieldEqualitor(fromField)));
+        stream.setStreamContext(streamContext);
+        stream.open();
+        BATCH:
+        while (true) {
+          Tuple tuple = stream.read();
+          if (tuple.EOF) {
+            break BATCH;
+          }
+          String _toNode = tuple.getString(toField);
+          String _fromNode = tuple.getString(fromField);
+          Edge edge = new Edge(_fromNode, _toNode);
+          edges.add(edge);
+        }
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      } finally {
+        // stream is still null if its construction failed; closing it would replace the real error with an NPE.
+        if (stream != null) {
+          try {
+            stream.close();
+          } catch(Exception ce) {
+            throw new RuntimeException(ce);
+          }
+        }
+      }
+      return edges;
+    }
+  }
+
+  /** A directed edge read from the index. */
+  private class Edge {
+
+    private String from;
+    private String to;
+
+    public Edge(String from, String to) {
+      this.from = from;
+      this.to = to;
+    }
+  }
+
+  /** Returns true if the node was already reached at any earlier level. */
+  private boolean cycle(String node, List<Map<String, List<String>>> allVisited) {
+    //Check all visited trees for each level to see if we've encountered this node before.
+    for(Map<String, List<String>> visited : allVisited) {
+      if(visited.containsKey(node)) {
+        return true;
+      }
+    }
+
+    return false;
+  }
+
+  public void close() throws IOException {
+    this.found = false;
+  }
+
+  /**
+   * Returns the next buffered path, or an EOF tuple once none remain. If no path was returned
+   * since the stream was last closed, the EOF tuple also carries a "sorry" field.
+   */
+  public Tuple read() throws IOException {
+    if(shortestPaths.size() > 0) {
+      found = true;
+      Tuple t = shortestPaths.removeFirst();
+      return t;
+    } else {
+      Map<String, Object> m = new HashMap<>();
+      m.put("EOF", true);
+      if(!found) {
+        m.put("sorry", "No path found");
+      }
+      return new Tuple(m);
+    }
+  }
+
+  public int getCost() {
+    return 0;
+  }
+
+  @Override
+  public StreamComparator getStreamSort() {
+    return null;
+  }
+}
\ No newline at end of file
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/package-info.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/package-info.java
new file mode 100644
index 00000000000..b34e0dd0e07
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Streaming Graph Traversals + **/ +package org.apache.solr.client.solrj.io.graph; + diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java new file mode 100644 index 00000000000..db58a905f7c --- /dev/null +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java @@ -0,0 +1,404 @@ +package org.apache.solr.client.solrj.io.graph; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import org.apache.lucene.util.LuceneTestCase; +import org.apache.lucene.util.LuceneTestCase.Slow; +import org.apache.solr.client.solrj.io.SolrClientCache; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.*; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.cloud.AbstractFullDistribZkTestBase; +import org.apache.solr.cloud.AbstractZkTestCase; +import org.apache.solr.common.SolrInputDocument; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * All base tests will be done with CloudSolrStream. Under the covers CloudSolrStream uses SolrStream so + * SolrStream will get fully exercised through these tests. + * + **/ + +@Slow +@LuceneTestCase.SuppressCodecs({"Lucene3x", "Lucene40","Lucene41","Lucene42","Lucene45"}) +public class GraphExpressionTest extends AbstractFullDistribZkTestBase { + + private static final String SOLR_HOME = getFile("solrj" + File.separator + "solr").getAbsolutePath(); + + static { + schemaString = "schema-streaming.xml"; + } + + @BeforeClass + public static void beforeSuperClass() { + AbstractZkTestCase.SOLRHOME = new File(SOLR_HOME()); + } + + @AfterClass + public static void afterSuperClass() { + + } + + protected String getCloudSolrConfig() { + return "solrconfig-streaming.xml"; + } + + + @Override + public String getSolrHome() { + return SOLR_HOME; + } + + public static String SOLR_HOME() { + return SOLR_HOME; + } + + @Before + @Override + public void setUp() throws Exception { + super.setUp(); + // we expect this time of exception as shards go up and down... 
+ //ignoreException(".*"); + + System.setProperty("numShards", Integer.toString(sliceCount)); + } + + @Override + @After + public void tearDown() throws Exception { + super.tearDown(); + resetExceptionIgnores(); + } + + public GraphExpressionTest() { + super(); + sliceCount = 2; + } + + @Test + public void testAll() throws Exception{ + assertNotNull(cloudClient); + + handle.clear(); + handle.put("timestamp", SKIPVAL); + + waitForRecoveriesToFinish(false); + + del("*:*"); + commit(); + + testShortestPathStream(); + } + + private void testShortestPathStream() throws Exception { + + indexr(id, "0", "from_s", "jim", "to_s", "mike", "predicate_s", "knows"); + indexr(id, "1", "from_s", "jim", "to_s", "dave", "predicate_s", "knows"); + indexr(id, "2", "from_s", "jim", "to_s", "stan", "predicate_s", "knows"); + indexr(id, "3", "from_s", "dave", "to_s", "stan", "predicate_s", "knows"); + indexr(id, "4", "from_s", "dave", "to_s", "bill", "predicate_s", "knows"); + indexr(id, "5", "from_s", "dave", "to_s", "mike", "predicate_s", "knows"); + indexr(id, "20", "from_s", "dave", "to_s", "alex", "predicate_s", "knows"); + indexr(id, "21", "from_s", "alex", "to_s", "steve", "predicate_s", "knows"); + indexr(id, "6", "from_s", "stan", "to_s", "alice", "predicate_s", "knows"); + indexr(id, "7", "from_s", "stan", "to_s", "mary", "predicate_s", "knows"); + indexr(id, "8", "from_s", "stan", "to_s", "dave", "predicate_s", "knows"); + indexr(id, "10", "from_s", "mary", "to_s", "mike", "predicate_s", "knows"); + indexr(id, "11", "from_s", "mary", "to_s", "max", "predicate_s", "knows"); + indexr(id, "12", "from_s", "mary", "to_s", "jim", "predicate_s", "knows"); + indexr(id, "13", "from_s", "mary", "to_s", "steve", "predicate_s", "knows"); + + commit(); + + List tuples = null; + Set paths = null; + ShortestPathStream stream = null; + StreamContext context = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + context.setSolrClientCache(cache); + + StreamFactory factory = 
new StreamFactory() + .withCollectionZkHost("collection1", zkServer.getZkAddress()) + .withFunctionName("shortestPath", ShortestPathStream.class); + + Map params = new HashMap(); + params.put("fq", "predicate_s:knows"); + + stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " + + "from=\"jim\", " + + "to=\"steve\"," + + "edge=\"from_s=to_s\"," + + "fq=\"predicate_s:knows\","+ + "threads=\"3\","+ + "partitionSize=\"3\","+ + "maxDepth=\"6\")"); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 2); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, dave, alex, steve]")); + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + //Test with batch size of 1 + + params.put("fq", "predicate_s:knows"); + + stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " + + "from=\"jim\", " + + "to=\"steve\"," + + "edge=\"from_s=to_s\"," + + "fq=\"predicate_s:knows\","+ + "threads=\"3\","+ + "partitionSize=\"1\","+ + "maxDepth=\"6\")"); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 2); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, dave, alex, steve]")); + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + //Test with bad predicate + + + stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " + + "from=\"jim\", " + + "to=\"steve\"," + + "edge=\"from_s=to_s\"," + + "fq=\"predicate_s:crap\","+ + "threads=\"3\","+ + "partitionSize=\"3\","+ + "maxDepth=\"6\")"); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 0); + + //Test with depth 2 + + stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " 
+ + "from=\"jim\", " + + "to=\"steve\"," + + "edge=\"from_s=to_s\"," + + "fq=\"predicate_s:knows\","+ + "threads=\"3\","+ + "partitionSize=\"3\","+ + "maxDepth=\"2\")"); + + + stream.setStreamContext(context); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 0); + + //Take out alex + params.put("fq", "predicate_s:knows NOT to_s:alex"); + + stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " + + "from=\"jim\", " + + "to=\"steve\"," + + "edge=\"from_s=to_s\"," + + "fq=\" predicate_s:knows NOT to_s:alex\","+ + "threads=\"3\","+ + "partitionSize=\"3\","+ + "maxDepth=\"6\")"); + + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + assertTrue(tuples.size() == 1); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + cache.close(); + del("*:*"); + commit(); + } + + protected List getTuples(TupleStream tupleStream) throws IOException { + tupleStream.open(); + List tuples = new ArrayList(); + for(Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) { + tuples.add(t); + } + tupleStream.close(); + return tuples; + } + protected boolean assertOrder(List tuples, int... ids) throws Exception { + return assertOrderOf(tuples, "id", ids); + } + protected boolean assertOrderOf(List tuples, String fieldName, int... ids) throws Exception { + int i = 0; + for(int val : ids) { + Tuple t = tuples.get(i); + Long tip = (Long)t.get(fieldName); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + protected boolean assertMapOrder(List tuples, int... 
ids) throws Exception { + int i = 0; + for(int val : ids) { + Tuple t = tuples.get(i); + List tip = t.getMaps("group"); + int id = (int)tip.get(0).get("id"); + if(id != val) { + throw new Exception("Found value:"+id+" expecting:"+val); + } + ++i; + } + return true; + } + + + protected boolean assertFields(List tuples, String ... fields) throws Exception{ + for(Tuple tuple : tuples){ + for(String field : fields){ + if(!tuple.fields.containsKey(field)){ + throw new Exception(String.format(Locale.ROOT, "Expected field '%s' not found", field)); + } + } + } + return true; + } + protected boolean assertNotFields(List tuples, String ... fields) throws Exception{ + for(Tuple tuple : tuples){ + for(String field : fields){ + if(tuple.fields.containsKey(field)){ + throw new Exception(String.format(Locale.ROOT, "Unexpected field '%s' found", field)); + } + } + } + return true; + } + + protected boolean assertGroupOrder(Tuple tuple, int... ids) throws Exception { + List group = (List)tuple.get("tuples"); + int i=0; + for(int val : ids) { + Map t = (Map)group.get(i); + Long tip = (Long)t.get("id"); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + public boolean assertLong(Tuple tuple, String fieldName, long l) throws Exception { + long lv = (long)tuple.get(fieldName); + if(lv != l) { + throw new Exception("Longs not equal:"+l+" : "+lv); + } + + return true; + } + + public boolean assertString(Tuple tuple, String fieldName, String expected) throws Exception { + String actual = (String)tuple.get(fieldName); + + if( (null == expected && null != actual) || + (null != expected && null == actual) || + (null != expected && !expected.equals(actual))){ + throw new Exception("Longs not equal:"+expected+" : "+actual); + } + + return true; + } + + protected boolean assertMaps(List maps, int... 
ids) throws Exception { + if(maps.size() != ids.length) { + throw new Exception("Expected id count != actual map count:"+ids.length+":"+maps.size()); + } + + int i=0; + for(int val : ids) { + Map t = maps.get(i); + Long tip = (Long)t.get("id"); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + private boolean assertList(List list, Object... vals) throws Exception { + + if(list.size() != vals.length) { + throw new Exception("Lists are not the same size:"+list.size() +" : "+vals.length); + } + + for(int i=0; i tuples = null; + Set paths = null; + ShortestPathStream stream = null; + String zkHost = zkServer.getZkAddress(); + StreamContext context = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + context.setSolrClientCache(cache); + + Map params = new HashMap(); + params.put("fq", "predicate_s:knows"); + + stream = new ShortestPathStream(zkHost, + "collection1", + "jim", + "steve", + "from_s", + "to_s", + params, + 20, + 3, + 6); + + + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 2); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, dave, alex, steve]")); + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + //Test with batch size of 1 + + params.put("fq", "predicate_s:knows"); + + stream = new ShortestPathStream(zkHost, + "collection1", + "jim", + "steve", + "from_s", + "to_s", + params, + 1, + 3, + 6); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 2); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, dave, alex, steve]")); + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + //Test with bad predicate + + params.put("fq", "predicate_s:crap"); 
+ + stream = new ShortestPathStream(zkHost, + "collection1", + "jim", + "steve", + "from_s", + "to_s", + params, + 1, + 3, + 6); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 0); + + //Test with depth 2 + + params.put("fq", "predicate_s:knows"); + + stream = new ShortestPathStream(zkHost, + "collection1", + "jim", + "steve", + "from_s", + "to_s", + params, + 1, + 3, + 2); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + + assertTrue(tuples.size() == 0); + + + + //Take out alex + params.put("fq", "predicate_s:knows NOT to_s:alex"); + + stream = new ShortestPathStream(zkHost, + "collection1", + "jim", + "steve", + "from_s", + "to_s", + params, + 10, + 3, + 6); + + stream.setStreamContext(context); + paths = new HashSet(); + tuples = getTuples(stream); + assertTrue(tuples.size() == 1); + + for(Tuple tuple : tuples) { + paths.add(tuple.getStrings("path").toString()); + } + + assertTrue(paths.contains("[jim, stan, mary, steve]")); + + cache.close(); + del("*:*"); + commit(); + } + + @Test + public void streamTests() throws Exception { + assertNotNull(cloudClient); + + handle.clear(); + handle.put("timestamp", SKIPVAL); + + waitForRecoveriesToFinish(false); + + del("*:*"); + + commit(); + + testShortestPathStream(); + + } + + protected Map mapParams(String... 
vals) { + Map params = new HashMap(); + String k = null; + for(String val : vals) { + if(k == null) { + k = val; + } else { + params.put(k, val); + k = null; + } + } + + return params; + } + + protected List getTuples(TupleStream tupleStream) throws IOException { + tupleStream.open(); + List tuples = new ArrayList(); + for(;;) { + Tuple t = tupleStream.read(); + if(t.EOF) { + break; + } else { + tuples.add(t); + } + } + tupleStream.close(); + return tuples; + } + + protected Tuple getTuple(TupleStream tupleStream) throws IOException { + tupleStream.open(); + Tuple t = tupleStream.read(); + tupleStream.close(); + return t; + } + + + protected boolean assertOrder(List tuples, int... ids) throws Exception { + int i = 0; + for(int val : ids) { + Tuple t = tuples.get(i); + Long tip = (Long)t.get("id"); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + protected boolean assertGroupOrder(Tuple tuple, int... ids) throws Exception { + List group = (List)tuple.get("tuples"); + int i=0; + for(int val : ids) { + Map t = (Map)group.get(i); + Long tip = (Long)t.get("id"); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + protected boolean assertMaps(List maps, int... 
ids) throws Exception { + if(maps.size() != ids.length) { + throw new Exception("Expected id count != actual map count:"+ids.length+":"+maps.size()); + } + + int i=0; + for(int val : ids) { + Map t = maps.get(i); + Long tip = (Long)t.get("id"); + if(tip.intValue() != val) { + throw new Exception("Found value:"+tip.intValue()+" expecting:"+val); + } + ++i; + } + return true; + } + + public boolean assertLong(Tuple tuple, String fieldName, long l) throws Exception { + long lv = (long)tuple.get(fieldName); + if(lv != l) { + throw new Exception("Longs not equal:"+l+" : "+lv); + } + + return true; + } + + @Override + protected void indexr(Object... fields) throws Exception { + SolrInputDocument doc = getDoc(fields); + indexDoc(doc); + } + + private void attachStreamFactory(TupleStream tupleStream) { + StreamContext streamContext = new StreamContext(); + streamContext.setStreamFactory(streamFactory); + tupleStream.setStreamContext(streamContext); + } +} +