SOLR-8888: Add shortestPath Streaming Expression

jbernste 2016-03-31 16:23:59 -04:00
parent 7263491d8e
commit 3500b45d6d
5 changed files with 1305 additions and 0 deletions
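For context, the expression syntax exercised by the tests below looks like this (the collection, node, and field names are the test fixtures):

shortestPath(collection1,
             from="jim",
             to="steve",
             edge="from_s=to_s",
             fq="predicate_s:knows",
             threads="3",
             partitionSize="3",
             maxDepth="6")

The stream runs a breadth-first search over the edge relation from_s=to_s and emits one tuple per distinct shortest path found.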

File: StreamHandler.java

@@ -28,6 +28,7 @@ import java.util.Map.Entry;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.graph.ShortestPathStream;
import org.apache.solr.client.solrj.io.ops.ConcatOperation;
import org.apache.solr.client.solrj.io.ops.DistinctOperation;
import org.apache.solr.client.solrj.io.ops.GroupOperation;
@@ -115,6 +116,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("complement", ComplementStream.class)
.withFunctionName("daemon", DaemonStream.class)
.withFunctionName("topic", TopicStream.class)
.withFunctionName("shortestPath", ShortestPathStream.class)
// metrics

File: ShortestPathStream.java

@@ -0,0 +1,490 @@
package org.apache.solr.client.solrj.io.graph;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.ArrayList;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.solr.client.solrj.io.eq.MultipleFieldEqualitor;
import org.apache.solr.client.solrj.io.stream.*;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.eq.FieldEqualitor;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.SolrjNamedThreadFactory;
public class ShortestPathStream extends TupleStream implements Expressible {
private static final long serialVersionUID = 1;
private String fromNode;
private String toNode;
private String fromField;
private String toField;
private int joinBatchSize;
private int maxDepth;
private String zkHost;
private String collection;
private LinkedList<Tuple> shortestPaths = new LinkedList();
private boolean found;
private StreamContext streamContext;
private int threads;
private Map queryParams;
public ShortestPathStream(String zkHost,
String collection,
String fromNode,
String toNode,
String fromField,
String toField,
Map queryParams,
int joinBatchSize,
int threads,
int maxDepth) {
init(zkHost,
collection,
fromNode,
toNode,
fromField,
toField,
queryParams,
joinBatchSize,
threads,
maxDepth);
}
public ShortestPathStream(StreamExpression expression, StreamFactory factory) throws IOException {
String collectionName = factory.getValueOperand(expression, 0);
List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
// Collection Name
if(null == collectionName) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - collectionName expected as first operand",expression));
}
String fromNode = null;
StreamExpressionNamedParameter fromExpression = factory.getNamedOperand(expression, "from");
if(fromExpression == null) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - from param is required",expression));
} else {
fromNode = ((StreamExpressionValue)fromExpression.getParameter()).getValue();
}
String toNode = null;
StreamExpressionNamedParameter toExpression = factory.getNamedOperand(expression, "to");
if(toExpression == null) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - to param is required", expression));
} else {
toNode = ((StreamExpressionValue)toExpression.getParameter()).getValue();
}
String fromField = null;
String toField = null;
StreamExpressionNamedParameter edgeExpression = factory.getNamedOperand(expression, "edge");
if(edgeExpression == null) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - edge param is required", expression));
} else {
String edge = ((StreamExpressionValue)edgeExpression.getParameter()).getValue();
String[] fields = edge.split("=");
if(fields.length != 2) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - edge param must contain two fields separated by '='", expression));
}
fromField = fields[0].trim();
toField = fields[1].trim();
}
int threads = 6;
StreamExpressionNamedParameter threadsExpression = factory.getNamedOperand(expression, "threads");
if(threadsExpression != null) {
threads = Integer.parseInt(((StreamExpressionValue)threadsExpression.getParameter()).getValue());
}
int partitionSize = 250;
StreamExpressionNamedParameter partitionExpression = factory.getNamedOperand(expression, "partitionSize");
if(partitionExpression != null) {
partitionSize = Integer.parseInt(((StreamExpressionValue)partitionExpression.getParameter()).getValue());
}
int maxDepth = 0;
StreamExpressionNamedParameter depthExpression = factory.getNamedOperand(expression, "maxDepth");
if(depthExpression == null) {
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - maxDepth param is required", expression));
} else {
maxDepth = Integer.parseInt(((StreamExpressionValue) depthExpression.getParameter()).getValue());
}
Map<String,String> params = new HashMap<String,String>();
for(StreamExpressionNamedParameter namedParam : namedParams){
if(!namedParam.getName().equals("zkHost") &&
!namedParam.getName().equals("to") &&
!namedParam.getName().equals("from") &&
!namedParam.getName().equals("edge") &&
!namedParam.getName().equals("maxDepth") &&
!namedParam.getName().equals("threads") &&
!namedParam.getName().equals("partitionSize"))
{
params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
}
}
// zkHost is optional - if not provided, resolve it from the factory's collection mapping or its default zkHost
String zkHost = null;
if(null == zkHostExpression){
zkHost = factory.getCollectionZkHost(collectionName);
if(zkHost == null) {
zkHost = factory.getDefaultZkHost();
}
} else if(zkHostExpression.getParameter() instanceof StreamExpressionValue) {
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
}
if(null == zkHost){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
}
// We've got all the required items
init(zkHost, collectionName, fromNode, toNode, fromField, toField, params, partitionSize, threads, maxDepth);
}
private void init(String zkHost,
String collection,
String fromNode,
String toNode,
String fromField,
String toField,
Map queryParams,
int joinBatchSize,
int threads,
int maxDepth) {
this.zkHost = zkHost;
this.collection = collection;
this.fromNode = fromNode;
this.toNode = toNode;
this.fromField = fromField;
this.toField = toField;
this.queryParams = queryParams;
this.joinBatchSize = joinBatchSize;
this.threads = threads;
this.maxDepth = maxDepth;
}
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
// collection
expression.addParameter(collection);
Set<Map.Entry> entries = queryParams.entrySet();
// parameters
for(Map.Entry param : entries){
String value = param.getValue().toString();
// SOLR-8409: This is a special case where the params contain a " character.
// Note that any other base streams with parameters where a " might come into play
// need this same replacement.
value = value.replace("\"", "\\\"");
expression.addParameter(new StreamExpressionNamedParameter(param.getKey().toString(), value));
}
expression.addParameter(new StreamExpressionNamedParameter("zkHost", zkHost));
expression.addParameter(new StreamExpressionNamedParameter("maxDepth", Integer.toString(maxDepth)));
expression.addParameter(new StreamExpressionNamedParameter("threads", Integer.toString(threads)));
expression.addParameter(new StreamExpressionNamedParameter("partitionSize", Integer.toString(joinBatchSize)));
expression.addParameter(new StreamExpressionNamedParameter("from", fromNode));
expression.addParameter(new StreamExpressionNamedParameter("to", toNode));
expression.addParameter(new StreamExpressionNamedParameter("edge", fromField+"="+toField));
return expression;
}
public void setStreamContext(StreamContext context) {
this.streamContext = context;
}
public List<TupleStream> children() {
List<TupleStream> l = new ArrayList();
return l;
}
public void open() throws IOException {
List<Map<String,List<String>>> allVisited = new ArrayList();
Map visited = new HashMap();
visited.put(this.fromNode, null);
allVisited.add(visited);
int depth = 0;
Map<String, List<String>> nextVisited = null;
List<Edge> targets = new ArrayList();
ExecutorService threadPool = null;
try {
threadPool = ExecutorUtil.newMDCAwareFixedThreadPool(threads, new SolrjNamedThreadFactory("ShortestPathStream"));
//Breadth first search
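//Each pass of TRAVERSE expands one level of the search: the current frontier
//('visited') is split into batches of joinBatchSize nodes, each batch is
//resolved to its outgoing edges by a JoinRunner on the thread pool, and the
//collected edges become the next frontier ('nextVisited'). The loop exits as
//soon as a level reaches toNode or maxDepth is hit.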
TRAVERSE:
while (targets.size() == 0 && depth < maxDepth) {
Set<String> nodes = visited.keySet();
Iterator<String> it = nodes.iterator();
nextVisited = new HashMap();
int batchCount = 0;
List<String> queryNodes = new ArrayList();
List<Future> futures = new ArrayList();
JOIN:
//Queue up all the batches
while (it.hasNext()) {
String node = it.next();
queryNodes.add(node);
++batchCount;
if (batchCount == joinBatchSize || !it.hasNext()) {
try {
JoinRunner joinRunner = new JoinRunner(queryNodes);
Future<List<Edge>> future = threadPool.submit(joinRunner);
futures.add(future);
} catch (Exception e) {
throw new RuntimeException(e);
}
batchCount = 0;
queryNodes = new ArrayList();
}
}
try {
//Process the batches as they become available
OUTER:
for (Future<List<Edge>> future : futures) {
List<Edge> edges = future.get();
INNER:
for (Edge edge : edges) {
if (toNode.equals(edge.to)) {
targets.add(edge);
if(nextVisited.containsKey(edge.to)) {
List<String> parents = nextVisited.get(edge.to);
parents.add(edge.from);
} else {
List<String> parents = new ArrayList();
parents.add(edge.from);
nextVisited.put(edge.to, parents);
}
} else {
if (!cycle(edge.to, allVisited)) {
if(nextVisited.containsKey(edge.to)) {
List<String> parents = nextVisited.get(edge.to);
parents.add(edge.from);
} else {
List<String> parents = new ArrayList();
parents.add(edge.from);
nextVisited.put(edge.to, parents);
}
}
}
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
allVisited.add(nextVisited);
visited = nextVisited;
++depth;
}
} finally {
threadPool.shutdown();
}
Set<String> finalPaths = new HashSet();
if(targets.size() > 0) {
for(Edge edge : targets) {
List<LinkedList> paths = new ArrayList();
LinkedList<String> path = new LinkedList();
path.addFirst(edge.to);
paths.add(path);
//Walk back up the levels and collect the parent nodes; a node with multiple recorded parents fans out into multiple equally short paths.
INNER:
for (int i = allVisited.size() - 1; i >= 0; --i) {
Map<String, List<String>> v = allVisited.get(i);
Iterator<LinkedList> it = paths.iterator();
List newPaths = new ArrayList();
while(it.hasNext()) {
LinkedList p = it.next();
List<String> parents = v.get(p.peekFirst());
if (parents != null) {
for(String parent : parents) {
LinkedList newPath = new LinkedList();
newPath.addAll(p);
newPath.addFirst(parent);
newPaths.add(newPath);
}
paths = newPaths;
}
}
}
for(LinkedList p : paths) {
String s = p.toString();
if (!finalPaths.contains(s)){
Tuple shortestPath = new Tuple(new HashMap());
shortestPath.put("path", p);
shortestPaths.add(shortestPath);
finalPaths.add(s);
}
}
}
}
}
private class JoinRunner implements Callable<List<Edge>> {
private List<String> nodes;
private List<Edge> edges = new ArrayList();
public JoinRunner(List<String> nodes) {
this.nodes = nodes;
}
public List<Edge> call() {
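//Resolve one batch of frontier nodes to their outgoing edges with a single
// /export query of the form fromField:(node1 node2 ...). Results stream back
//sorted on (toField, fromField) so UniqueStream can drop duplicate edges.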
Map joinParams = new HashMap();
String fl = fromField + "," + toField;
joinParams.putAll(queryParams);
joinParams.put("fl", fl);
joinParams.put("qt", "/export");
joinParams.put("sort", toField + " asc,"+fromField +" asc");
StringBuffer nodeQuery = new StringBuffer();
for(String node : nodes) {
nodeQuery.append(node).append(" ");
}
String q = fromField + ":(" + nodeQuery.toString().trim() + ")";
joinParams.put("q", q);
TupleStream stream = null;
try {
stream = new UniqueStream(new CloudSolrStream(zkHost, collection, joinParams), new MultipleFieldEqualitor(new FieldEqualitor(toField), new FieldEqualitor(fromField)));
stream.setStreamContext(streamContext);
stream.open();
BATCH:
while (true) {
Tuple tuple = stream.read();
if (tuple.EOF) {
break BATCH;
}
String _toNode = tuple.getString(toField);
String _fromNode = tuple.getString(fromField);
Edge edge = new Edge(_fromNode, _toNode);
edges.add(edge);
}
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
try {
stream.close();
} catch(Exception ce) {
throw new RuntimeException(ce);
}
}
return edges;
}
}
private class Edge {
private String from;
private String to;
public Edge(String from, String to) {
this.from = from;
this.to = to;
}
}
private boolean cycle(String node, List<Map<String,List<String>>> allVisited) {
//Check the visited map at every level to see if we've encountered this node before; revisiting a node cannot yield a shorter path.
for(Map<String, List<String>> visited : allVisited) {
if(visited.containsKey(node)) {
return true;
}
}
return false;
}
public void close() throws IOException {
this.found = false;
}
public Tuple read() throws IOException {
if(shortestPaths.size() > 0) {
found = true;
Tuple t = shortestPaths.removeFirst();
return t;
} else {
Map m = new HashMap();
m.put("EOF", true);
if(!found) {
m.put("sorry", "No path found");
}
return new Tuple(m);
}
}
public int getCost() {
return 0;
}
@Override
public StreamComparator getStreamSort() {
return null;
}
}
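As read() above shows, each result tuple carries a "path" field holding the node list, followed by an EOF tuple that adds a "sorry" field when no path was found. Against the test data below the output is roughly:

path: [jim, dave, alex, steve]
path: [jim, stan, mary, steve]
EOF: true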

File: package-info.java

@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Streaming Graph Traversals
**/
package org.apache.solr.client.solrj.io.graph;

File: GraphExpressionTest.java

@@ -0,0 +1,404 @@
package org.apache.solr.client.solrj.io.graph;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.*;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.cloud.AbstractFullDistribZkTestBase;
import org.apache.solr.cloud.AbstractZkTestCase;
import org.apache.solr.common.SolrInputDocument;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
/**
* All base tests will be done with CloudSolrStream. Under the covers CloudSolrStream uses SolrStream so
* SolrStream will get fully exercised through these tests.
*
**/
@Slow
@LuceneTestCase.SuppressCodecs({"Lucene3x", "Lucene40","Lucene41","Lucene42","Lucene45"})
public class GraphExpressionTest extends AbstractFullDistribZkTestBase {
private static final String SOLR_HOME = getFile("solrj" + File.separator + "solr").getAbsolutePath();
static {
schemaString = "schema-streaming.xml";
}
@BeforeClass
public static void beforeSuperClass() {
AbstractZkTestCase.SOLRHOME = new File(SOLR_HOME());
}
@AfterClass
public static void afterSuperClass() {
}
protected String getCloudSolrConfig() {
return "solrconfig-streaming.xml";
}
@Override
public String getSolrHome() {
return SOLR_HOME;
}
public static String SOLR_HOME() {
return SOLR_HOME;
}
@Before
@Override
public void setUp() throws Exception {
super.setUp();
// we expect this type of exception as shards go up and down...
//ignoreException(".*");
System.setProperty("numShards", Integer.toString(sliceCount));
}
@Override
@After
public void tearDown() throws Exception {
super.tearDown();
resetExceptionIgnores();
}
public GraphExpressionTest() {
super();
sliceCount = 2;
}
@Test
public void testAll() throws Exception{
assertNotNull(cloudClient);
handle.clear();
handle.put("timestamp", SKIPVAL);
waitForRecoveriesToFinish(false);
del("*:*");
commit();
testShortestPathStream();
}
private void testShortestPathStream() throws Exception {
indexr(id, "0", "from_s", "jim", "to_s", "mike", "predicate_s", "knows");
indexr(id, "1", "from_s", "jim", "to_s", "dave", "predicate_s", "knows");
indexr(id, "2", "from_s", "jim", "to_s", "stan", "predicate_s", "knows");
indexr(id, "3", "from_s", "dave", "to_s", "stan", "predicate_s", "knows");
indexr(id, "4", "from_s", "dave", "to_s", "bill", "predicate_s", "knows");
indexr(id, "5", "from_s", "dave", "to_s", "mike", "predicate_s", "knows");
indexr(id, "20", "from_s", "dave", "to_s", "alex", "predicate_s", "knows");
indexr(id, "21", "from_s", "alex", "to_s", "steve", "predicate_s", "knows");
indexr(id, "6", "from_s", "stan", "to_s", "alice", "predicate_s", "knows");
indexr(id, "7", "from_s", "stan", "to_s", "mary", "predicate_s", "knows");
indexr(id, "8", "from_s", "stan", "to_s", "dave", "predicate_s", "knows");
indexr(id, "10", "from_s", "mary", "to_s", "mike", "predicate_s", "knows");
indexr(id, "11", "from_s", "mary", "to_s", "max", "predicate_s", "knows");
indexr(id, "12", "from_s", "mary", "to_s", "jim", "predicate_s", "knows");
indexr(id, "13", "from_s", "mary", "to_s", "steve", "predicate_s", "knows");
commit();
List<Tuple> tuples = null;
Set<String> paths = null;
ShortestPathStream stream = null;
StreamContext context = new StreamContext();
SolrClientCache cache = new SolrClientCache();
context.setSolrClientCache(cache);
StreamFactory factory = new StreamFactory()
.withCollectionZkHost("collection1", zkServer.getZkAddress())
.withFunctionName("shortestPath", ShortestPathStream.class);
Map params = new HashMap();
params.put("fq", "predicate_s:knows");
stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " +
"from=\"jim\", " +
"to=\"steve\"," +
"edge=\"from_s=to_s\"," +
"fq=\"predicate_s:knows\","+
"threads=\"3\","+
"partitionSize=\"3\","+
"maxDepth=\"6\")");
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 2);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, dave, alex, steve]"));
assertTrue(paths.contains("[jim, stan, mary, steve]"));
//Test with batch size of 1
params.put("fq", "predicate_s:knows");
stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " +
"from=\"jim\", " +
"to=\"steve\"," +
"edge=\"from_s=to_s\"," +
"fq=\"predicate_s:knows\","+
"threads=\"3\","+
"partitionSize=\"1\","+
"maxDepth=\"6\")");
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 2);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, dave, alex, steve]"));
assertTrue(paths.contains("[jim, stan, mary, steve]"));
//Test with bad predicate
stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " +
"from=\"jim\", " +
"to=\"steve\"," +
"edge=\"from_s=to_s\"," +
"fq=\"predicate_s:crap\","+
"threads=\"3\","+
"partitionSize=\"3\","+
"maxDepth=\"6\")");
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 0);
//Test with depth 2
stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " +
"from=\"jim\", " +
"to=\"steve\"," +
"edge=\"from_s=to_s\"," +
"fq=\"predicate_s:knows\","+
"threads=\"3\","+
"partitionSize=\"3\","+
"maxDepth=\"2\")");
stream.setStreamContext(context);
tuples = getTuples(stream);
assertTrue(tuples.size() == 0);
//Take out alex
params.put("fq", "predicate_s:knows NOT to_s:alex");
stream = (ShortestPathStream)factory.constructStream("shortestPath(collection1, " +
"from=\"jim\", " +
"to=\"steve\"," +
"edge=\"from_s=to_s\"," +
"fq=\" predicate_s:knows NOT to_s:alex\","+
"threads=\"3\","+
"partitionSize=\"3\","+
"maxDepth=\"6\")");
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 1);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, stan, mary, steve]"));
cache.close();
del("*:*");
commit();
}
protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
tupleStream.open();
List<Tuple> tuples = new ArrayList<Tuple>();
for(Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
tuples.add(t);
}
tupleStream.close();
return tuples;
}
protected boolean assertOrder(List<Tuple> tuples, int... ids) throws Exception {
return assertOrderOf(tuples, "id", ids);
}
protected boolean assertOrderOf(List<Tuple> tuples, String fieldName, int... ids) throws Exception {
int i = 0;
for(int val : ids) {
Tuple t = tuples.get(i);
Long tip = (Long)t.get(fieldName);
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
protected boolean assertMapOrder(List<Tuple> tuples, int... ids) throws Exception {
int i = 0;
for(int val : ids) {
Tuple t = tuples.get(i);
List<Map> tip = t.getMaps("group");
int id = (int)tip.get(0).get("id");
if(id != val) {
throw new Exception("Found value:"+id+" expecting:"+val);
}
++i;
}
return true;
}
protected boolean assertFields(List<Tuple> tuples, String ... fields) throws Exception{
for(Tuple tuple : tuples){
for(String field : fields){
if(!tuple.fields.containsKey(field)){
throw new Exception(String.format(Locale.ROOT, "Expected field '%s' not found", field));
}
}
}
return true;
}
protected boolean assertNotFields(List<Tuple> tuples, String ... fields) throws Exception{
for(Tuple tuple : tuples){
for(String field : fields){
if(tuple.fields.containsKey(field)){
throw new Exception(String.format(Locale.ROOT, "Unexpected field '%s' found", field));
}
}
}
return true;
}
protected boolean assertGroupOrder(Tuple tuple, int... ids) throws Exception {
List<?> group = (List<?>)tuple.get("tuples");
int i=0;
for(int val : ids) {
Map<?,?> t = (Map<?,?>)group.get(i);
Long tip = (Long)t.get("id");
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
public boolean assertLong(Tuple tuple, String fieldName, long l) throws Exception {
long lv = (long)tuple.get(fieldName);
if(lv != l) {
throw new Exception("Longs not equal:"+l+" : "+lv);
}
return true;
}
public boolean assertString(Tuple tuple, String fieldName, String expected) throws Exception {
String actual = (String)tuple.get(fieldName);
if( (null == expected && null != actual) ||
(null != expected && null == actual) ||
(null != expected && !expected.equals(actual))){
throw new Exception("Longs not equal:"+expected+" : "+actual);
}
return true;
}
protected boolean assertMaps(List<Map> maps, int... ids) throws Exception {
if(maps.size() != ids.length) {
throw new Exception("Expected id count != actual map count:"+ids.length+":"+maps.size());
}
int i=0;
for(int val : ids) {
Map t = maps.get(i);
Long tip = (Long)t.get("id");
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
private boolean assertList(List list, Object... vals) throws Exception {
if(list.size() != vals.length) {
throw new Exception("Lists are not the same size:"+list.size() +" : "+vals.length);
}
for(int i=0; i<list.size(); i++) {
Object a = list.get(i);
Object b = vals[i];
if(!a.equals(b)) {
throw new Exception("List items not equals:"+a+" : "+b);
}
}
return true;
}
@Override
protected void indexr(Object... fields) throws Exception {
SolrInputDocument doc = getDoc(fields);
indexDoc(doc);
}
}

File: GraphTest.java

@@ -0,0 +1,387 @@
package org.apache.solr.client.solrj.io.graph;
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.cloud.AbstractFullDistribZkTestBase;
import org.apache.solr.cloud.AbstractZkTestCase;
import org.apache.solr.common.SolrInputDocument;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.Set;
import java.util.HashSet;
/**
* All base tests will be done with CloudSolrStream. Under the covers CloudSolrStream uses SolrStream so
* SolrStream will get fully exercised through these tests.
*
**/
@LuceneTestCase.Slow
@LuceneTestCase.SuppressCodecs({"Lucene3x", "Lucene40","Lucene41","Lucene42","Lucene45"})
public class GraphTest extends AbstractFullDistribZkTestBase {
private static final String SOLR_HOME = getFile("solrj" + File.separator + "solr").getAbsolutePath();
private StreamFactory streamFactory;
static {
schemaString = "schema-streaming.xml";
}
@BeforeClass
public static void beforeSuperClass() {
AbstractZkTestCase.SOLRHOME = new File(SOLR_HOME());
}
@AfterClass
public static void afterSuperClass() {
}
protected String getCloudSolrConfig() {
return "solrconfig-streaming.xml";
}
@Override
public String getSolrHome() {
return SOLR_HOME;
}
public static String SOLR_HOME() {
return SOLR_HOME;
}
@Before
@Override
public void setUp() throws Exception {
super.setUp();
// we expect this type of exception as shards go up and down...
//ignoreException(".*");
//System.setProperty("export.test", "true");
System.setProperty("numShards", Integer.toString(sliceCount));
}
@Override
@After
public void tearDown() throws Exception {
super.tearDown();
resetExceptionIgnores();
}
public GraphTest() {
super();
sliceCount = 2;
}
private void testShortestPathStream() throws Exception {
indexr(id, "0", "from_s", "jim", "to_s", "mike", "predicate_s", "knows");
indexr(id, "1", "from_s", "jim", "to_s", "dave", "predicate_s", "knows");
indexr(id, "2", "from_s", "jim", "to_s", "stan", "predicate_s", "knows");
indexr(id, "3", "from_s", "dave", "to_s", "stan", "predicate_s", "knows");
indexr(id, "4", "from_s", "dave", "to_s", "bill", "predicate_s", "knows");
indexr(id, "5", "from_s", "dave", "to_s", "mike", "predicate_s", "knows");
indexr(id, "20", "from_s", "dave", "to_s", "alex", "predicate_s", "knows");
indexr(id, "21", "from_s", "alex", "to_s", "steve", "predicate_s", "knows");
indexr(id, "6", "from_s", "stan", "to_s", "alice", "predicate_s", "knows");
indexr(id, "7", "from_s", "stan", "to_s", "mary", "predicate_s", "knows");
indexr(id, "8", "from_s", "stan", "to_s", "dave", "predicate_s", "knows");
indexr(id, "10", "from_s", "mary", "to_s", "mike", "predicate_s", "knows");
indexr(id, "11", "from_s", "mary", "to_s", "max", "predicate_s", "knows");
indexr(id, "12", "from_s", "mary", "to_s", "jim", "predicate_s", "knows");
indexr(id, "13", "from_s", "mary", "to_s", "steve", "predicate_s", "knows");
commit();
List<Tuple> tuples = null;
Set<String> paths = null;
ShortestPathStream stream = null;
String zkHost = zkServer.getZkAddress();
StreamContext context = new StreamContext();
SolrClientCache cache = new SolrClientCache();
context.setSolrClientCache(cache);
Map params = new HashMap();
params.put("fq", "predicate_s:knows");
stream = new ShortestPathStream(zkHost,
"collection1",
"jim",
"steve",
"from_s",
"to_s",
params,
20,
3,
6);
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 2);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, dave, alex, steve]"));
assertTrue(paths.contains("[jim, stan, mary, steve]"));
//Test with batch size of 1
params.put("fq", "predicate_s:knows");
stream = new ShortestPathStream(zkHost,
"collection1",
"jim",
"steve",
"from_s",
"to_s",
params,
1,
3,
6);
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 2);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, dave, alex, steve]"));
assertTrue(paths.contains("[jim, stan, mary, steve]"));
//Test with bad predicate
params.put("fq", "predicate_s:crap");
stream = new ShortestPathStream(zkHost,
"collection1",
"jim",
"steve",
"from_s",
"to_s",
params,
1,
3,
6);
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 0);
//Test with depth 2
params.put("fq", "predicate_s:knows");
stream = new ShortestPathStream(zkHost,
"collection1",
"jim",
"steve",
"from_s",
"to_s",
params,
1,
3,
2);
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 0);
//Take out alex
params.put("fq", "predicate_s:knows NOT to_s:alex");
stream = new ShortestPathStream(zkHost,
"collection1",
"jim",
"steve",
"from_s",
"to_s",
params,
10,
3,
6);
stream.setStreamContext(context);
paths = new HashSet();
tuples = getTuples(stream);
assertTrue(tuples.size() == 1);
for(Tuple tuple : tuples) {
paths.add(tuple.getStrings("path").toString());
}
assertTrue(paths.contains("[jim, stan, mary, steve]"));
cache.close();
del("*:*");
commit();
}
@Test
public void streamTests() throws Exception {
assertNotNull(cloudClient);
handle.clear();
handle.put("timestamp", SKIPVAL);
waitForRecoveriesToFinish(false);
del("*:*");
commit();
testShortestPathStream();
}
protected Map mapParams(String... vals) {
Map params = new HashMap();
String k = null;
for(String val : vals) {
if(k == null) {
k = val;
} else {
params.put(k, val);
k = null;
}
}
return params;
}
protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
tupleStream.open();
List<Tuple> tuples = new ArrayList();
for(;;) {
Tuple t = tupleStream.read();
if(t.EOF) {
break;
} else {
tuples.add(t);
}
}
tupleStream.close();
return tuples;
}
protected Tuple getTuple(TupleStream tupleStream) throws IOException {
tupleStream.open();
Tuple t = tupleStream.read();
tupleStream.close();
return t;
}
protected boolean assertOrder(List<Tuple> tuples, int... ids) throws Exception {
int i = 0;
for(int val : ids) {
Tuple t = tuples.get(i);
Long tip = (Long)t.get("id");
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
protected boolean assertGroupOrder(Tuple tuple, int... ids) throws Exception {
List group = (List)tuple.get("tuples");
int i=0;
for(int val : ids) {
Map t = (Map)group.get(i);
Long tip = (Long)t.get("id");
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
protected boolean assertMaps(List<Map> maps, int... ids) throws Exception {
if(maps.size() != ids.length) {
throw new Exception("Expected id count != actual map count:"+ids.length+":"+maps.size());
}
int i=0;
for(int val : ids) {
Map t = maps.get(i);
Long tip = (Long)t.get("id");
if(tip.intValue() != val) {
throw new Exception("Found value:"+tip.intValue()+" expecting:"+val);
}
++i;
}
return true;
}
public boolean assertLong(Tuple tuple, String fieldName, long l) throws Exception {
long lv = (long)tuple.get(fieldName);
if(lv != l) {
throw new Exception("Longs not equal:"+l+" : "+lv);
}
return true;
}
@Override
protected void indexr(Object... fields) throws Exception {
SolrInputDocument doc = getDoc(fields);
indexDoc(doc);
}
private void attachStreamFactory(TupleStream tupleStream) {
StreamContext streamContext = new StreamContext();
streamContext.setStreamFactory(streamFactory);
tupleStream.setStreamContext(streamContext);
}
}