From d77574abbad62cdf80a8f8978ec439f8a7e6da72 Mon Sep 17 00:00:00 2001 From: jbernste Date: Mon, 18 Apr 2016 15:36:12 -0400 Subject: [PATCH] SOLR-8925: Add gatherNodes Streaming Expression to support breadth first traversals Conflicts: solr/core/src/java/org/apache/solr/handler/StreamHandler.java --- .../apache/solr/handler/StreamHandler.java | 14 +- .../solrj/io/graph/GatherNodesStream.java | 580 ++++++++++++++++++ .../solr/client/solrj/io/graph/Node.java | 90 +++ .../solr/client/solrj/io/graph/Traversal.java | 96 +++ .../solrj/io/graph/TraversalIterator.java | 120 ++++ .../client/solrj/io/stream/StreamContext.java | 4 + .../solrj/io/stream/metrics/CountMetric.java | 4 + .../solrj/io/stream/metrics/MaxMetric.java | 5 + .../solrj/io/stream/metrics/MeanMetric.java | 5 + .../solrj/io/stream/metrics/Metric.java | 2 + .../solrj/io/stream/metrics/MinMetric.java | 7 +- .../solrj/io/stream/metrics/SumMetric.java | 5 + .../solrj/io/graph/GraphExpressionTest.java | 402 +++++++++++- 13 files changed, 1329 insertions(+), 5 deletions(-) create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java index 2a19584e494..c33a7d4ec09 100644 --- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java +++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java @@ -28,6 +28,7 @@ import java.util.Map.Entry; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; +import org.apache.solr.client.solrj.io.graph.GatherNodesStream; import org.apache.solr.client.solrj.io.graph.ShortestPathStream; import org.apache.solr.client.solrj.io.ops.ConcatOperation; import org.apache.solr.client.solrj.io.ops.DistinctOperation; @@ -120,9 +121,8 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, .withFunctionName("daemon", DaemonStream.class) .withFunctionName("sort", SortStream.class) .withFunctionName("select", SelectStream.class) - - // graph streams - .withFunctionName("shortestPath", ShortestPathStream.class) + .withFunctionName("shortestPath", ShortestPathStream.class) + .withFunctionName("gatherNodes", GatherNodesStream.class) // metrics .withFunctionName("min", MinMetric.class) @@ -275,6 +275,14 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware, public Tuple read() { String msg = e.getMessage(); + + Throwable t = e.getCause(); + while(t != null) { + msg = t.getMessage(); + t = t.getCause(); + } + + Map m = new HashMap(); m.put("EOF", true); m.put("EXCEPTION", msg); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java new file mode 100644 index 00000000000..759aa0f8431 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/GatherNodesStream.java @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.graph; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.HashSet; +import java.util.ArrayList; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor; +import org.apache.solr.client.solrj.io.stream.*; +import org.apache.solr.client.solrj.io.stream.metrics.*; +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.comp.StreamComparator; +import org.apache.solr.client.solrj.io.eq.FieldEqualitor; +import org.apache.solr.client.solrj.io.stream.expr.Expressible; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpression; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter; +import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue; +import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.common.util.ExecutorUtil; +import org.apache.solr.common.util.SolrjNamedThreadFactory; + +public class GatherNodesStream extends TupleStream implements Expressible { + + private String zkHost; + private String collection; + private StreamContext streamContext; + private Map queryParams; + private String traverseFrom; + private String traverseTo; + private String gather; + private boolean trackTraversal; + private boolean useDefaultTraversal; + + private TupleStream tupleStream; + private Set scatter; + private Iterator out; + private Traversal traversal; + private List metrics; + + public GatherNodesStream(String zkHost, + String collection, + TupleStream tupleStream, + String traverseFrom, + String traverseTo, + String gather, + Map queryParams, + List metrics, + boolean trackTraversal, + Set scatter) { + + init(zkHost, + collection, + tupleStream, + traverseFrom, + traverseTo, + gather, + queryParams, + metrics, + trackTraversal, + scatter); + } + + public GatherNodesStream(StreamExpression expression, StreamFactory factory) throws IOException { + + + String collectionName = factory.getValueOperand(expression, 0); + List namedParams = factory.getNamedOperands(expression); + StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost"); + + List streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class); + // Collection Name + if(null == collectionName) { + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression)); + } + + + Set scatter = new HashSet(); + + StreamExpressionNamedParameter scatterExpression = factory.getNamedOperand(expression, "scatter"); + + if(scatterExpression == null) { + scatter.add(Traversal.Scatter.LEAVES); + } else { + String s = ((StreamExpressionValue)scatterExpression.getParameter()).getValue(); + String[] sArray = s.split(","); + for(String sv : sArray) { + sv = sv.trim(); + if(Traversal.Scatter.BRANCHES.toString().equalsIgnoreCase(sv)) { + scatter.add(Traversal.Scatter.BRANCHES); + } else if (Traversal.Scatter.LEAVES.toString().equalsIgnoreCase(sv)) { + scatter.add(Traversal.Scatter.LEAVES); + } + } + } + + String gather = null; + StreamExpressionNamedParameter gatherExpression = factory.getNamedOperand(expression, "gather"); + + if(gatherExpression == null) { + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - from param is required",expression)); + } else { + gather = ((StreamExpressionValue)gatherExpression.getParameter()).getValue(); + } + + String traverseFrom = null; + String traverseTo = null; + StreamExpressionNamedParameter edgeExpression = factory.getNamedOperand(expression, "walk"); + + TupleStream stream = null; + + if(edgeExpression == null) { + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - walk param is required", expression)); + } else { + if(streamExpressions.size() > 0) { + stream = factory.constructStream(streamExpressions.get(0)); + String edge = ((StreamExpressionValue) edgeExpression.getParameter()).getValue(); + String[] fields = edge.split("->"); + if (fields.length != 2) { + throw new IOException(String.format(Locale.ROOT, "invalid expression %s - walk param separated by an -> and must contain two fields", expression)); + } + traverseFrom = fields[0].trim(); + traverseTo = fields[1].trim(); + } else { + String edge = ((StreamExpressionValue) edgeExpression.getParameter()).getValue(); + String[] fields = edge.split("->"); + if (fields.length != 2) { + throw new IOException(String.format(Locale.ROOT, "invalid expression %s - walk param separated by an -> and must contain two fields", expression)); + } + + String[] rootNodes = fields[0].split(","); + List l = new ArrayList(); + for(String n : rootNodes) { + l.add(n.trim()); + } + + stream = new NodeStream(l); + traverseFrom = "node"; + traverseTo = fields[1].trim(); + } + } + + List metricExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, Metric.class); + List metrics = new ArrayList(); + for(int idx = 0; idx < metricExpressions.size(); ++idx){ + metrics.add(factory.constructMetric(metricExpressions.get(idx))); + } + + boolean trackTraversal = false; + + StreamExpressionNamedParameter trackExpression = factory.getNamedOperand(expression, "trackTraversal"); + + if(trackExpression != null) { + trackTraversal = Boolean.parseBoolean(((StreamExpressionValue) trackExpression.getParameter()).getValue()); + } else { + useDefaultTraversal = true; + } + + StreamExpressionNamedParameter scopeExpression = factory.getNamedOperand(expression, "localScope"); + + if(trackExpression != null) { + trackTraversal = Boolean.parseBoolean(((StreamExpressionValue) trackExpression.getParameter()).getValue()); + } + + Map params = new HashMap(); + for(StreamExpressionNamedParameter namedParam : namedParams){ + if(!namedParam.getName().equals("zkHost") && + !namedParam.getName().equals("gather") && + !namedParam.getName().equals("walk") && + !namedParam.getName().equals("scatter") && + !namedParam.getName().equals("trackTraversal")) + { + params.put(namedParam.getName(), namedParam.getParameter().toString().trim()); + } + } + + // zkHost, optional - if not provided then will look into factory list to get + String zkHost = null; + if(null == zkHostExpression){ + zkHost = factory.getCollectionZkHost(collectionName); + if(zkHost == null) { + zkHost = factory.getDefaultZkHost(); + } + } else if(zkHostExpression.getParameter() instanceof StreamExpressionValue) { + zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue(); + } + + if(null == zkHost){ + throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName)); + } + + // We've got all the required items + init(zkHost, + collectionName, + stream, + traverseFrom, + traverseTo , + gather, + params, + metrics, + trackTraversal, + scatter); + } + + private void init(String zkHost, + String collection, + TupleStream tupleStream, + String traverseFrom, + String traverseTo, + String gather, + Map queryParams, + List metrics, + boolean trackTraversal, + Set scatter) { + this.zkHost = zkHost; + this.collection = collection; + this.tupleStream = tupleStream; + this.traverseFrom = traverseFrom; + this.traverseTo = traverseTo; + this.gather = gather; + this.queryParams = queryParams; + this.metrics = metrics; + this.trackTraversal = trackTraversal; + this.scatter = scatter; + } + + @Override + public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException { + + StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass())); + + // collection + expression.addParameter(collection); + + if(tupleStream instanceof Expressible){ + expression.addParameter(((Expressible)tupleStream).toExpression(factory)); + } + else{ + throw new IOException("This GatherNodesStream contains a non-expressible TupleStream - it cannot be converted to an expression"); + } + + Set entries = queryParams.entrySet(); + // parameters + for(Map.Entry param : entries){ + String value = param.getValue().toString(); + + // SOLR-8409: This is a special case where the params contain a " character + // Do note that in any other BASE streams with parameters where a " might come into play + // that this same replacement needs to take place. + value = value.replace("\"", "\\\""); + + expression.addParameter(new StreamExpressionNamedParameter(param.getKey().toString(), value)); + } + + if(metrics != null) { + for (Metric metric : metrics) { + expression.addParameter(metric.toExpression(factory)); + } + } + + expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost)); + expression.addParameter(new StreamExpressionNamedParameter("gather", zkHost)); + expression.addParameter(new StreamExpressionNamedParameter("walk", traverseFrom+"->"+traverseTo)); + expression.addParameter(new StreamExpressionNamedParameter("trackTraversal", Boolean.toString(trackTraversal))); + + StringBuilder buf = new StringBuilder(); + for(Traversal.Scatter sc : scatter) { + if(buf.length() > 0 ) { + buf.append(","); + } + buf.append(sc.toString()); + } + + expression.addParameter(new StreamExpressionNamedParameter("scatter", buf.toString())); + + return expression; + } + + public void setStreamContext(StreamContext context) { + this.traversal = (Traversal) context.get("traversal"); + if (traversal == null) { + //No traversal in the context. So create a new context and a new traversal. + //This ensures that two separate traversals in the same expression don't pollute each others traversal. + StreamContext localContext = new StreamContext(); + + localContext.numWorkers = context.numWorkers; + localContext.workerID = context.workerID; + localContext.setSolrClientCache(context.getSolrClientCache()); + localContext.setStreamFactory(context.getStreamFactory()); + + for(Object key :context.getEntries().keySet()) { + localContext.put(key, context.get(key)); + } + + traversal = new Traversal(); + + localContext.put("traversal", traversal); + + this.tupleStream.setStreamContext(localContext); + this.streamContext = localContext; + } else { + this.tupleStream.setStreamContext(context); + this.streamContext = context; + } + } + + public List children() { + List l = new ArrayList(); + l.add(tupleStream); + return l; + } + + public void open() throws IOException { + tupleStream.open(); + } + + private class JoinRunner implements Callable> { + + private List nodes; + private List edges = new ArrayList(); + + public JoinRunner(List nodes) { + this.nodes = nodes; + } + + public List call() { + + Map joinParams = new HashMap(); + Set flSet = new HashSet(); + flSet.add(gather); + flSet.add(traverseTo); + + //Add the metric columns + + if(metrics != null) { + for(Metric metric : metrics) { + for(String column : metric.getColumns()) { + flSet.add(column); + } + } + } + + if(queryParams.containsKey("fl")) { + String flString = (String)queryParams.get("fl"); + String[] flArray = flString.split(","); + for(String f : flArray) { + flSet.add(f.trim()); + } + } + + Iterator it = flSet.iterator(); + StringBuilder buf = new StringBuilder(); + while(it.hasNext()) { + buf.append(it.next()); + if(it.hasNext()) { + buf.append(","); + } + } + + joinParams.putAll(queryParams); + joinParams.put("fl", buf.toString()); + joinParams.put("qt", "/export"); + joinParams.put("sort", gather + " asc,"+traverseTo +" asc"); + + StringBuffer nodeQuery = new StringBuffer(); + + for(String node : nodes) { + nodeQuery.append(node).append(" "); + } + + String q = traverseTo + ":(" + nodeQuery.toString().trim() + ")"; + + + joinParams.put("q", q); + TupleStream stream = null; + try { + stream = new UniqueStream(new CloudSolrStream(zkHost, collection, joinParams), new MultipleFieldEqualitor(new FieldEqualitor(gather), new FieldEqualitor(traverseTo))); + stream.setStreamContext(streamContext); + stream.open(); + BATCH: + while (true) { + Tuple tuple = stream.read(); + if (tuple.EOF) { + break BATCH; + } + + edges.add(tuple); + } + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + try { + stream.close(); + } catch(Exception ce) { + throw new RuntimeException(ce); + } + } + return edges; + } + } + + + public void close() throws IOException { + tupleStream.close(); + } + + public Tuple read() throws IOException { + + if (out == null) { + List joinBatch = new ArrayList(); + List>> futures = new ArrayList(); + Map level = new HashMap(); + + ExecutorService threadPool = null; + try { + threadPool = ExecutorUtil.newMDCAwareFixedThreadPool(4, new SolrjNamedThreadFactory("GatherNodesStream")); + + Map roots = new HashMap(); + + while (true) { + Tuple tuple = tupleStream.read(); + if (tuple.EOF) { + if (joinBatch.size() > 0) { + JoinRunner joinRunner = new JoinRunner(joinBatch); + Future future = threadPool.submit(joinRunner); + futures.add(future); + } + break; + } + + String value = tuple.getString(traverseFrom); + + if(traversal.getDepth() == 0) { + //This gathers the root nodes + //We check to see if there are dupes in the root nodes because root streams may not have been uniqued. + String key = collection+"."+value; + if(!roots.containsKey(key)) { + Node node = new Node(value, trackTraversal); + if (metrics != null) { + List _metrics = new ArrayList(); + for (Metric metric : metrics) { + _metrics.add(metric.newInstance()); + } + node.setMetrics(_metrics); + } + + roots.put(key, node); + } else { + continue; + } + } + + joinBatch.add(value); + if (joinBatch.size() == 400) { + JoinRunner joinRunner = new JoinRunner(joinBatch); + Future future = threadPool.submit(joinRunner); + futures.add(future); + joinBatch = new ArrayList(); + } + } + + if(traversal.getDepth() == 0) { + traversal.addLevel(roots, collection, traverseFrom); + } + + this.traversal.setScatter(scatter); + + if(useDefaultTraversal) { + this.trackTraversal = traversal.getTrackTraversal(); + } else { + this.traversal.setTrackTraversal(trackTraversal); + } + + for (Future> future : futures) { + List tuples = future.get(); + for (Tuple tuple : tuples) { + String _traverseTo = tuple.getString(traverseTo); + String _gather = tuple.getString(gather); + String key = collection + "." + _gather; + if (!traversal.visited(key, _traverseTo, tuple)) { + Node node = level.get(key); + if (node != null) { + node.add((traversal.getDepth()-1)+"^"+_traverseTo, tuple); + } else { + node = new Node(_gather, trackTraversal); + if (metrics != null) { + List _metrics = new ArrayList(); + for (Metric metric : metrics) { + _metrics.add(metric.newInstance()); + } + node.setMetrics(_metrics); + } + node.add((traversal.getDepth()-1)+"^"+_traverseTo, tuple); + level.put(key, node); + } + } + } + } + + traversal.addLevel(level, collection, gather); + out = traversal.iterator(); + } catch(Exception e) { + throw new RuntimeException(e); + } finally { + threadPool.shutdown(); + } + } + + if (out.hasNext()) { + return out.next(); + } else { + Map map = new HashMap(); + map.put("EOF", true); + Tuple tuple = new Tuple(map); + return tuple; + } + } + + public int getCost() { + return 0; + } + + @Override + public StreamComparator getStreamSort() { + return null; + } + + class NodeStream extends TupleStream { + + private List ids; + private Iterator it; + + public NodeStream(List ids) { + this.ids = ids; + } + + public void open() {this.it = ids.iterator();} + public void close() {} + public StreamComparator getStreamSort() {return null;} + public List children() {return new ArrayList();} + public void setStreamContext(StreamContext context) {} + + public Tuple read() { + HashMap map = new HashMap(); + if(it.hasNext()) { + map.put("node",it.next()); + return new Tuple(map); + } else { + + map.put("EOF", true); + return new Tuple(map); + } + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java new file mode 100644 index 00000000000..befa5a7721c --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Node.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.graph; + +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.stream.metrics.*; +import java.util.*; + +public class Node { + + private String id; + private List metrics; + private Set ancestors; + + public Node(String id, boolean track) { + this.id=id; + if(track) { + ancestors = new HashSet(); + } + } + + public void setMetrics(List metrics) { + this.metrics = metrics; + } + + public void add(String ancestor, Tuple tuple) { + if(ancestors != null) { + ancestors.add(ancestor); + } + + if(metrics != null) { + for(Metric metric : metrics) { + metric.update(tuple); + } + } + } + + public Tuple toTuple(String collection, String field, int level, Traversal traversal) { + Map map = new HashMap(); + + map.put("node", id); + map.put("collection", collection); + map.put("field", field); + map.put("level", level); + + boolean prependCollection = traversal.isMultiCollection(); + List cols = traversal.getCollections(); + + if(ancestors != null) { + List l = new ArrayList(); + for(String ancestor : ancestors) { + String[] ancestorParts = ancestor.split("\\^"); + + if(prependCollection) { + //prepend the collection + int colIndex = Integer.parseInt(ancestorParts[0]); + l.add(cols.get(colIndex)+"/"+ancestorParts[1]); + } else { + // Use only the ancestor id. + l.add(ancestorParts[1]); + } + } + + map.put("ancestors", l); + } + + if(metrics != null) { + for(Metric metric : metrics) { + map.put(metric.getIdentifier(), metric.getValue()); + } + } + + return new Tuple(map); + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java new file mode 100644 index 00000000000..43d23b33b19 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/Traversal.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.graph; + +import org.apache.solr.client.solrj.io.Tuple; +import java.util.*; + +public class Traversal { + + private List> graph = new ArrayList(); + private List fields = new ArrayList(); + private List collections = new ArrayList(); + private Set scatter = new HashSet(); + private Set collectionSet = new HashSet(); + private boolean trackTraversal; + private int depth; + + public void addLevel(Map level, String collection, String field) { + graph.add(level); + collections.add(collection); + collectionSet.add(collection); + fields.add(field); + ++depth; + } + + public int getDepth() { + return depth; + } + + public boolean getTrackTraversal() { + return this.trackTraversal; + } + + public boolean visited(String nodeId, String ancestorId, Tuple tuple) { + for(Map level : graph) { + Node node = level.get(nodeId); + if(node != null) { + node.add(depth+"^"+ancestorId, tuple); + return true; + } + } + return false; + } + + public boolean isMultiCollection() { + return collectionSet.size() > 1; + } + + public List> getGraph() { + return graph; + } + + public void setScatter(Set scatter) { + this.scatter = scatter; + } + + public Set getScatter() { + return this.scatter; + } + + public void setTrackTraversal(boolean trackTraversal) { + this.trackTraversal = trackTraversal; + } + + public List getCollections() { + return this.collections; + } + + public List getFields() { + return this.fields; + } + + public enum Scatter { + BRANCHES, + LEAVES; + } + + public Iterator iterator() { + return new TraversalIterator(this, scatter); + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java new file mode 100644 index 00000000000..7cfe3756fb7 --- /dev/null +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/graph/TraversalIterator.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.solr.client.solrj.io.graph; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.graph.Traversal.Scatter; + +class TraversalIterator implements Iterator { + + private List> graph; + private List collections; + private List fields; + + private Iterator> graphIterator; + private Iterator levelIterator; + + private Iterator fieldIterator; + private Iterator collectionIterator; + private Iterator levelNumIterator; + private String outField; + private String outCollection; + private int outLevel; + private Traversal traversal; + + public TraversalIterator(Traversal traversal, Set scatter) { + this.traversal = traversal; + graph = traversal.getGraph(); + collections = traversal.getCollections(); + fields = traversal.getFields(); + + List outCollections = new ArrayList(); + List outFields = new ArrayList(); + List levelNums = new ArrayList(); + List> levelIterators = new ArrayList(); + + if(scatter.contains(Scatter.BRANCHES)) { + if(graph.size() > 1) { + for(int i=0; i graphLevel = graph.get(i); + String collection = collections.get(i); + String field = fields.get(i); + outCollections.add(collection); + outFields.add(field); + levelNums.add(i); + levelIterators.add(graphLevel.values().iterator()); + } + } + } + + if(scatter.contains(Scatter.LEAVES)) { + int leavesLevel = graph.size() > 1 ? graph.size()-1 : 0 ; + Map graphLevel = graph.get(leavesLevel); + String collection = collections.get(leavesLevel); + String field = fields.get(leavesLevel); + levelNums.add(leavesLevel); + outCollections.add(collection); + outFields.add(field); + levelIterators.add(graphLevel.values().iterator()); + } + + graphIterator = levelIterators.iterator(); + levelIterator = graphIterator.next(); + + fieldIterator = outFields.iterator(); + collectionIterator = outCollections.iterator(); + levelNumIterator = levelNums.iterator(); + + outField = fieldIterator.next(); + outCollection = collectionIterator.next(); + outLevel = levelNumIterator.next(); + } + + @Override + public boolean hasNext() { + if(levelIterator.hasNext()) { + return true; + } else { + if(graphIterator.hasNext()) { + levelIterator = graphIterator.next(); + outField = fieldIterator.next(); + outCollection = collectionIterator.next(); + outLevel = levelNumIterator.next(); + return hasNext(); + } else { + return false; + } + } + } + + @Override + public Tuple next() { + if(hasNext()) { + Node node = levelIterator.next(); + return node.toTuple(outCollection, outField, outLevel, traversal); + } else { + return null; + } + } +} \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java index ff0aefa4d9a..87e30356ec3 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StreamContext.java @@ -49,6 +49,10 @@ public class StreamContext implements Serializable{ this.entries.put(key, value); } + public Map getEntries() { + return this.entries; + } + public void setSolrClientCache(SolrClientCache clientCache) { this.clientCache = clientCache; } diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java index 0e19177ff05..445b530163d 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/CountMetric.java @@ -49,6 +49,10 @@ public class CountMetric extends Metric implements Serializable { init(functionName); } + + public String[] getColumns() { + return new String[0]; + } private void init(String functionName){ setFunctionName(functionName); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java index 8f2069e472d..0594bf42249 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MaxMetric.java @@ -67,6 +67,11 @@ public class MaxMetric extends Metric implements Serializable { } } + public String[] getColumns() { + String[] cols = {columnName}; + return cols; + } + public void update(Tuple tuple) { Object o = tuple.get(columnName); if(o instanceof Double) { diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java index 0a5726c95bc..097e04b822b 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MeanMetric.java @@ -80,6 +80,11 @@ public class MeanMetric extends Metric implements Serializable { return new MeanMetric(columnName); } + public String[] getColumns() { + String[] cols = {columnName}; + return cols; + } + public double getValue() { double dcount = (double)count; if(longSum == 0) { diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java index e7321828dd0..07a400a50e8 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/Metric.java @@ -54,4 +54,6 @@ public abstract class Metric implements Serializable, Expressible { public abstract double getValue(); public abstract void update(Tuple tuple); public abstract Metric newInstance(); + public abstract String[] getColumns(); + } \ No newline at end of file diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java index 7c6060e9ddf..0a565809080 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/MinMetric.java @@ -56,7 +56,12 @@ public class MinMetric extends Metric { setFunctionName(functionName); setIdentifier(functionName, "(", columnName, ")"); } - + + + public String[] getColumns() { + String[] cols = {columnName}; + return cols; + } public double getValue() { if(longMin == Long.MAX_VALUE) { diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java index 805f9781283..578dae764fb 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/metrics/SumMetric.java @@ -58,6 +58,11 @@ public class SumMetric extends Metric implements Serializable { setIdentifier(functionName, "(", columnName, ")"); } + public String[] getColumns() { + String[] cols = {columnName}; + return cols; + } + public void update(Tuple tuple) { Object o = tuple.get(columnName); if(o instanceof Double) { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java index db58a905f7c..b5231e295c8 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java @@ -20,6 +20,7 @@ package org.apache.solr.client.solrj.io.graph; import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -31,8 +32,15 @@ import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.LuceneTestCase.Slow; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; +import org.apache.solr.client.solrj.io.comp.ComparatorOrder; +import org.apache.solr.client.solrj.io.comp.FieldComparator; import org.apache.solr.client.solrj.io.stream.*; import org.apache.solr.client.solrj.io.stream.expr.StreamFactory; +import org.apache.solr.client.solrj.io.stream.metrics.CountMetric; +import org.apache.solr.client.solrj.io.stream.metrics.MaxMetric; +import org.apache.solr.client.solrj.io.stream.metrics.MeanMetric; +import org.apache.solr.client.solrj.io.stream.metrics.MinMetric; +import org.apache.solr.client.solrj.io.stream.metrics.SumMetric; import org.apache.solr.cloud.AbstractFullDistribZkTestBase; import org.apache.solr.cloud.AbstractZkTestCase; import org.apache.solr.common.SolrInputDocument; @@ -117,6 +125,8 @@ public class GraphExpressionTest extends AbstractFullDistribZkTestBase { commit(); testShortestPathStream(); + testGatherNodesStream(); + testGatherNodesFriendsStream(); } private void testShortestPathStream() throws Exception { @@ -265,9 +275,399 @@ public class GraphExpressionTest extends AbstractFullDistribZkTestBase { commit(); } + + private void testGatherNodesStream() throws Exception { + + indexr(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20"); + indexr(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30"); + indexr(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1"); + indexr(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2"); + indexr(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5"); + indexr(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10"); + indexr(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20"); + indexr(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10"); + indexr(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10"); + indexr(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40"); + indexr(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10"); + indexr(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10"); + + commit(); + + List tuples = null; + Set paths = null; + GatherNodesStream stream = null; + StreamContext context = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + context.setSolrClientCache(cache); + + StreamFactory factory = new StreamFactory() + .withCollectionZkHost("collection1", zkServer.getZkAddress()) + .withFunctionName("gatherNodes", GatherNodesStream.class) + .withFunctionName("search", CloudSolrStream.class) + .withFunctionName("count", CountMetric.class) + .withFunctionName("avg", MeanMetric.class) + .withFunctionName("sum", SumMetric.class) + .withFunctionName("min", MinMetric.class) + .withFunctionName("max", MaxMetric.class); + + String expr = "gatherNodes(collection1, " + + "walk=\"product1->product_s\"," + + "gather=\"basket_s\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 4); + assertTrue(tuples.get(0).getString("node").equals("basket1")); + assertTrue(tuples.get(1).getString("node").equals("basket2")); + assertTrue(tuples.get(2).getString("node").equals("basket3")); + assertTrue(tuples.get(3).getString("node").equals("basket4")); + + String expr2 = "gatherNodes(collection1, " + + expr+","+ + "walk=\"node->basket_s\"," + + "gather=\"product_s\", count(*), avg(price_f), sum(price_f), min(price_f), max(price_f))"; + + stream = (GatherNodesStream)factory.constructStream(expr2); + + context = new StreamContext(); + context.setSolrClientCache(cache); + + stream.setStreamContext(context); + + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + + + assertTrue(tuples.size() == 5); + + + assertTrue(tuples.get(0).getString("node").equals("product3")); + assertTrue(tuples.get(0).getDouble("count(*)").equals(3.0D)); + + assertTrue(tuples.get(1).getString("node").equals("product4")); + assertTrue(tuples.get(1).getDouble("count(*)").equals(2.0D)); + assertTrue(tuples.get(1).getDouble("avg(price_f)").equals(30.0D)); + assertTrue(tuples.get(1).getDouble("sum(price_f)").equals(60.0D)); + assertTrue(tuples.get(1).getDouble("min(price_f)").equals(20.0D)); + assertTrue(tuples.get(1).getDouble("max(price_f)").equals(40.0D)); + + assertTrue(tuples.get(2).getString("node").equals("product5")); + assertTrue(tuples.get(2).getDouble("count(*)").equals(1.0D)); + assertTrue(tuples.get(3).getString("node").equals("product6")); + assertTrue(tuples.get(3).getDouble("count(*)").equals(1.0D)); + assertTrue(tuples.get(4).getString("node").equals("product7")); + assertTrue(tuples.get(4).getDouble("count(*)").equals(1.0D)); + + //Test list of root nodes + expr = "gatherNodes(collection1, " + + "walk=\"product4, product7->product_s\"," + + "gather=\"basket_s\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + tuples = getTuples(stream); + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 3); + assertTrue(tuples.get(0).getString("node").equals("basket2")); + assertTrue(tuples.get(1).getString("node").equals("basket3")); + assertTrue(tuples.get(2).getString("node").equals("basket4")); + + //Test with negative filter query + + expr = "gatherNodes(collection1, " + + "walk=\"product4, product7->product_s\"," + + "gather=\"basket_s\", fq=\"-basket_s:basket4\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 2); + assertTrue(tuples.get(0).getString("node").equals("basket2")); + assertTrue(tuples.get(1).getString("node").equals("basket3")); + + cache.close(); + del("*:*"); + commit(); + } + + private void testGatherNodesFriendsStream() throws Exception { + + indexr(id, "0", "from_s", "bill", "to_s", "jim", "message_t", "Hello jim"); + indexr(id, "1", "from_s", "bill", "to_s", "sam", "message_t", "Hello sam"); + indexr(id, "2", "from_s", "bill", "to_s", "max", "message_t", "Hello max"); + indexr(id, "3", "from_s", "max", "to_s", "kip", "message_t", "Hello kip"); + indexr(id, "4", "from_s", "sam", "to_s", "steve", "message_t", "Hello steve"); + indexr(id, "5", "from_s", "jim", "to_s", "ann", "message_t", "Hello steve"); + + commit(); + + List tuples = null; + Set paths = null; + GatherNodesStream stream = null; + StreamContext context = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + context.setSolrClientCache(cache); + + StreamFactory factory = new StreamFactory() + .withCollectionZkHost("collection1", zkServer.getZkAddress()) + .withFunctionName("gatherNodes", GatherNodesStream.class) + .withFunctionName("search", CloudSolrStream.class) + .withFunctionName("count", CountMetric.class) + .withFunctionName("hashJoin", HashJoinStream.class) + .withFunctionName("avg", MeanMetric.class) + .withFunctionName("sum", SumMetric.class) + .withFunctionName("min", MinMetric.class) + .withFunctionName("max", MaxMetric.class); + + String expr = "gatherNodes(collection1, " + + "walk=\"bill->from_s\"," + + "gather=\"to_s\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 3); + assertTrue(tuples.get(0).getString("node").equals("jim")); + assertTrue(tuples.get(1).getString("node").equals("max")); + assertTrue(tuples.get(2).getString("node").equals("sam")); + + //Test scatter branches, leaves and trackTraversal + + expr = "gatherNodes(collection1, " + + "walk=\"bill->from_s\"," + + "gather=\"to_s\","+ + "scatter=\"branches, leaves\", trackTraversal=\"true\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 4); + assertTrue(tuples.get(0).getString("node").equals("bill")); + assertTrue(tuples.get(0).getLong("level").equals(new Long(0))); + assertTrue(tuples.get(0).getStrings("ancestors").size() == 0); + assertTrue(tuples.get(1).getString("node").equals("jim")); + assertTrue(tuples.get(1).getLong("level").equals(new Long(1))); + List ancestors = tuples.get(1).getStrings("ancestors"); + System.out.println("##################### Ancestors:"+ancestors); + assert(ancestors.size() == 1); + assert(ancestors.get(0).equals("bill")); + + assertTrue(tuples.get(2).getString("node").equals("max")); + assertTrue(tuples.get(2).getLong("level").equals(new Long(1))); + ancestors = tuples.get(2).getStrings("ancestors"); + assert(ancestors.size() == 1); + assert(ancestors.get(0).equals("bill")); + + assertTrue(tuples.get(3).getString("node").equals("sam")); + assertTrue(tuples.get(3).getLong("level").equals(new Long(1))); + ancestors = tuples.get(3).getStrings("ancestors"); + assert(ancestors.size() == 1); + assert(ancestors.get(0).equals("bill")); + + // Test query root + + expr = "gatherNodes(collection1, " + + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+ + "walk=\"from_s->from_s\"," + + "gather=\"to_s\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 3); + assertTrue(tuples.get(0).getString("node").equals("jim")); + assertTrue(tuples.get(1).getString("node").equals("max")); + assertTrue(tuples.get(2).getString("node").equals("sam")); + + + // Test query root scatter branches + + expr = "gatherNodes(collection1, " + + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+ + "walk=\"from_s->from_s\"," + + "gather=\"to_s\", scatter=\"branches, leaves\")"; + + stream = (GatherNodesStream)factory.constructStream(expr); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + assertTrue(tuples.size() == 4); + assertTrue(tuples.get(0).getString("node").equals("bill")); + assertTrue(tuples.get(0).getLong("level").equals(new Long(0))); + assertTrue(tuples.get(1).getString("node").equals("jim")); + assertTrue(tuples.get(1).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(2).getString("node").equals("max")); + assertTrue(tuples.get(2).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(3).getString("node").equals("sam")); + assertTrue(tuples.get(3).getLong("level").equals(new Long(1))); + + expr = "gatherNodes(collection1, " + + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+ + "walk=\"from_s->from_s\"," + + "gather=\"to_s\")"; + + String expr2 = "gatherNodes(collection1, " + + expr+","+ + "walk=\"node->from_s\"," + + "gather=\"to_s\")"; + + stream = (GatherNodesStream)factory.constructStream(expr2); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + + assertTrue(tuples.size() == 3); + assertTrue(tuples.get(0).getString("node").equals("ann")); + assertTrue(tuples.get(1).getString("node").equals("kip")); + assertTrue(tuples.get(2).getString("node").equals("steve")); + + + //Test two traversals in the same expression + String expr3 = "hashJoin("+expr2+", hashed="+expr2+", on=\"node\")"; + + HashJoinStream hstream = (HashJoinStream)factory.constructStream(expr3); + context = new StreamContext(); + context.setSolrClientCache(cache); + hstream.setStreamContext(context); + + tuples = getTuples(hstream); + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + + assertTrue(tuples.size() == 3); + assertTrue(tuples.get(0).getString("node").equals("ann")); + assertTrue(tuples.get(1).getString("node").equals("kip")); + assertTrue(tuples.get(2).getString("node").equals("steve")); + + //================================= + + + expr = "gatherNodes(collection1, " + + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+ + "walk=\"from_s->from_s\"," + + "gather=\"to_s\")"; + + expr2 = "gatherNodes(collection1, " + + expr+","+ + "walk=\"node->from_s\"," + + "gather=\"to_s\", scatter=\"branches, leaves\")"; + + stream = (GatherNodesStream)factory.constructStream(expr2); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + + + assertTrue(tuples.size() == 7); + assertTrue(tuples.get(0).getString("node").equals("ann")); + assertTrue(tuples.get(0).getLong("level").equals(new Long(2))); + assertTrue(tuples.get(1).getString("node").equals("bill")); + assertTrue(tuples.get(1).getLong("level").equals(new Long(0))); + assertTrue(tuples.get(2).getString("node").equals("jim")); + assertTrue(tuples.get(2).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(3).getString("node").equals("kip")); + assertTrue(tuples.get(3).getLong("level").equals(new Long(2))); + assertTrue(tuples.get(4).getString("node").equals("max")); + assertTrue(tuples.get(4).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(5).getString("node").equals("sam")); + assertTrue(tuples.get(5).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(6).getString("node").equals("steve")); + assertTrue(tuples.get(6).getLong("level").equals(new Long(2))); + + //Add a cycle from jim to bill + indexr(id, "6", "from_s", "jim", "to_s", "bill", "message_t", "Hello steve"); + indexr(id, "7", "from_s", "sam", "to_s", "bill", "message_t", "Hello steve"); + + commit(); + + expr = "gatherNodes(collection1, " + + "search(collection1, q=\"message_t:jim\", fl=\"from_s\", sort=\"from_s asc\"),"+ + "walk=\"from_s->from_s\"," + + "gather=\"to_s\", trackTraversal=\"true\")"; + + expr2 = "gatherNodes(collection1, " + + expr+","+ + "walk=\"node->from_s\"," + + "gather=\"to_s\", scatter=\"branches, leaves\", trackTraversal=\"true\")"; + + stream = (GatherNodesStream)factory.constructStream(expr2); + context = new StreamContext(); + context.setSolrClientCache(cache); + stream.setStreamContext(context); + + tuples = getTuples(stream); + Collections.sort(tuples, new FieldComparator("node", ComparatorOrder.ASCENDING)); + + assertTrue(tuples.size() == 7); + assertTrue(tuples.get(0).getString("node").equals("ann")); + assertTrue(tuples.get(0).getLong("level").equals(new Long(2))); + //Bill should now have one ancestor + assertTrue(tuples.get(1).getString("node").equals("bill")); + assertTrue(tuples.get(1).getLong("level").equals(new Long(0))); + assertTrue(tuples.get(1).getStrings("ancestors").size() == 2); + List anc = tuples.get(1).getStrings("ancestors"); + + Collections.sort(anc); + assertTrue(anc.get(0).equals("jim")); + assertTrue(anc.get(1).equals("sam")); + + assertTrue(tuples.get(2).getString("node").equals("jim")); + assertTrue(tuples.get(2).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(3).getString("node").equals("kip")); + assertTrue(tuples.get(3).getLong("level").equals(new Long(2))); + assertTrue(tuples.get(4).getString("node").equals("max")); + assertTrue(tuples.get(4).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(5).getString("node").equals("sam")); + assertTrue(tuples.get(5).getLong("level").equals(new Long(1))); + assertTrue(tuples.get(6).getString("node").equals("steve")); + assertTrue(tuples.get(6).getLong("level").equals(new Long(2))); + + cache.close(); + del("*:*"); + commit(); + } + + + protected List getTuples(TupleStream tupleStream) throws IOException { tupleStream.open(); - List tuples = new ArrayList(); + List tuples = new ArrayList(); for(Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) { tuples.add(t); }