SOLR-10254: significantTerms Streaming Expression should work in non-SolrCloud mode

This commit is contained in:
Joel Bernstein 2017-03-08 21:10:56 -05:00
parent 8756be0540
commit 682c6a7d51
4 changed files with 286 additions and 120 deletions

View File

@ -18,6 +18,7 @@ package org.apache.solr.handler;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@ -246,6 +247,7 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
int worker = params.getInt("workerID", 0);
int numWorkers = params.getInt("numWorkers", 1);
StreamContext context = new StreamContext();
context.put("shards", getCollectionShards(params));
context.workerID = worker;
context.numWorkers = numWorkers;
context.setSolrClientCache(clientCache);
@ -509,4 +511,29 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
return tuple;
}
}
/**
 * Extracts manual shard lists from http params of the form
 * {@code <collection>.shards=<coreUrl1>,<coreUrl2>,...} so streaming
 * expressions can resolve shards without SolrCloud/ZooKeeper.
 *
 * @param params request parameters to scan
 * @return map of collection name to its shard core URLs, or null when no
 *         {@code *.shards} parameter is present (null signals "use SolrCloud").
 */
private Map<String, List<String>> getCollectionShards(SolrParams params) {
  Map<String, List<String>> collectionShards = new HashMap<>();
  Iterator<String> paramsIt = params.getParameterNamesIterator();
  while (paramsIt.hasNext()) {
    String param = paramsIt.next();
    // Same match as indexOf(".shards") > -1 in spirit: any param carrying
    // the ".shards" suffix convention is treated as a manual shard list.
    if (param.contains(".shards")) {
      // The collection name is everything before the first '.'.
      String collection = param.split("\\.")[0];
      String shardString = params.get(param);
      // Collections.addAll replaces the manual element-copy loop; the list
      // stays mutable for downstream callers.
      List<String> shardList = new ArrayList<>();
      Collections.addAll(shardList, shardString.split(","));
      collectionShards.put(collection, shardList);
    }
  }
  // Preserve the original contract: null means "no manual sharding supplied".
  return collectionShards.isEmpty() ? null : collectionShards;
}
}

View File

@ -74,12 +74,9 @@ public class SignificantTermsStream extends TupleStream implements Expressible{
protected transient SolrClientCache cache;
protected transient boolean isCloseCache;
protected transient CloudSolrClient cloudSolrClient;
protected transient StreamContext streamContext;
protected ExecutorService executorService;
public SignificantTermsStream(String zkHost,
String collectionName,
Map params,
@ -168,12 +165,12 @@ public class SignificantTermsStream extends TupleStream implements Expressible{
String zkHost = null;
if(null == zkHostExpression){
zkHost = factory.getCollectionZkHost(collectionName);
}
else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){
} else if(zkHostExpression.getParameter() instanceof StreamExpressionValue) {
zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue();
}
if(null == zkHost){
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName));
if(zkHost == null){
zkHost = factory.getDefaultZkHost();
}
// We've got all the required items
@ -238,47 +235,13 @@ public class SignificantTermsStream extends TupleStream implements Expressible{
isCloseCache = false;
}
this.cloudSolrClient = this.cache.getCloudSolrClient(zkHost);
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("FeaturesSelectionStream"));
this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrjNamedThreadFactory("SignificantTermsStream"));
}
// This stream reads term statistics directly from shards rather than wrapping
// other streams, so it has no child streams. NOTE(review): returns null rather
// than an empty list — callers must tolerate the null.
public List<TupleStream> children() {
return null;
}
// Resolves one core URL per slice of this.collection via ZooKeeper cluster
// state, choosing a random ACTIVE replica on a live node for each slice.
// Any failure (including an empty replica list, which surfaces as an
// IndexOutOfBoundsException from shuffler.get(0)) is wrapped in IOException.
private List<String> getShardUrls() throws IOException {
try {
ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
// checkAlias=false: aliases are not resolved on this path.
Collection<Slice> slices = CloudSolrStream.getSlices(this.collection, zkStateReader, false);
ClusterState clusterState = zkStateReader.getClusterState();
Set<String> liveNodes = clusterState.getLiveNodes();
List<String> baseUrls = new ArrayList<>();
for(Slice slice : slices) {
Collection<Replica> replicas = slice.getReplicas();
List<Replica> shuffler = new ArrayList<>();
for(Replica replica : replicas) {
// Only consider replicas that are ACTIVE and whose node is live.
if(replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
shuffler.add(replica);
}
}
// Shuffle then take the head: a uniformly random eligible replica.
Collections.shuffle(shuffler, new Random());
Replica rep = shuffler.get(0);
ZkCoreNodeProps zkProps = new ZkCoreNodeProps(rep);
String url = zkProps.getCoreUrl();
baseUrls.add(url);
}
return baseUrls;
} catch (Exception e) {
// Wrap everything (ZK errors, empty-slice lookups) in IOException.
throw new IOException(e);
}
}
private List<Future<NamedList>> callShards(List<String> baseUrls) throws IOException {
List<Future<NamedList>> futures = new ArrayList<>();
@ -326,7 +289,7 @@ public class SignificantTermsStream extends TupleStream implements Expressible{
Map<String, int[]> mergeFreqs = new HashMap<>();
long numDocs = 0;
long resultCount = 0;
for (Future<NamedList> getTopTermsCall : callShards(getShardUrls())) {
for (Future<NamedList> getTopTermsCall : callShards(getShards(zkHost, collection, streamContext))) {
NamedList resp = getTopTermsCall.get();
List<String> terms = (List<String>)resp.get("sterms");

View File

@ -19,9 +19,16 @@ package org.apache.solr.client.solrj.io.stream;
import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.Map;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
@ -29,6 +36,14 @@ import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.IteratorWriter;
import org.apache.solr.common.MapWriter;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.cloud.Aliases;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.DocCollection;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.util.StrUtils;
public abstract class TupleStream implements Closeable, Serializable, MapWriter {
@ -84,4 +99,83 @@ public abstract class TupleStream implements Closeable, Serializable, MapWriter
public UUID getStreamNodeId(){
return streamNodeId;
}
/**
 * Resolves the shard (core) URLs for a collection.
 *
 * Manual shard lists may be supplied through the StreamContext under the
 * "shards" key (collection name -> list of core URLs); when present they take
 * precedence, which lets streams run outside of SolrCloud. Otherwise the
 * shards are discovered from ZooKeeper cluster state, picking one random
 * ACTIVE replica on a live node per slice.
 *
 * @param zkHost        ZooKeeper connect string (used only on the cloud path)
 * @param collection    collection (or alias) whose shards are wanted
 * @param streamContext may be null; consulted for the manual "shards" map
 * @return list of core URLs, one per slice (manual path may return null if
 *         the map does not contain the collection — preserved behavior)
 * @throws IOException if slices cannot be resolved or a slice has no
 *                     active replica on a live node
 */
public static List<String> getShards(String zkHost,
                                     String collection,
                                     StreamContext streamContext)
    throws IOException {

  Map<String, List<String>> shardsMap = null;
  if (streamContext != null) {
    shardsMap = (Map<String, List<String>>) streamContext.get("shards");
  }

  if (shardsMap != null) {
    // Manual sharding: the caller already knows the shard URLs.
    return shardsMap.get(collection);
  }

  // SolrCloud sharding: one random active, live replica per slice.
  List<String> shards = new ArrayList<>();
  CloudSolrClient cloudSolrClient = streamContext.getSolrClientCache().getCloudSolrClient(zkHost);
  ZkStateReader zkStateReader = cloudSolrClient.getZkStateReader();
  ClusterState clusterState = zkStateReader.getClusterState();
  Collection<Slice> slices = getSlices(collection, zkStateReader, true);
  Set<String> liveNodes = clusterState.getLiveNodes();
  for (Slice slice : slices) {
    List<Replica> shuffler = new ArrayList<>();
    for (Replica replica : slice.getReplicas()) {
      if (replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
        shuffler.add(replica);
      }
    }
    // Fix: guard against a slice with no eligible replicas; the original
    // shuffled then called shuffler.get(0), which throws an uncaught
    // IndexOutOfBoundsException on an empty list.
    if (shuffler.isEmpty()) {
      throw new IOException("No active replicas found for collection: " + collection);
    }
    Collections.shuffle(shuffler, new Random());
    Replica rep = shuffler.get(0);
    shards.add(new ZkCoreNodeProps(rep).getCoreUrl());
  }
  return shards;
}
/**
 * Looks up the active slices for a collection name, trying in order:
 * exact match, case-insensitive match, and (optionally) collection alias.
 *
 * @param collectionName collection name or alias
 * @param zkStateReader  source of cluster state and aliases
 * @param checkAlias     whether to fall back to alias resolution
 * @return the active slices; for a multi-collection alias, the union of
 *         each member collection's active slices
 * @throws IOException if no collection or alias matches, or an alias
 *                     references a collection missing from cluster state
 */
public static Collection<Slice> getSlices(String collectionName,
                                          ZkStateReader zkStateReader,
                                          boolean checkAlias) throws IOException {
  ClusterState clusterState = zkStateReader.getClusterState();
  Map<String, DocCollection> collectionsMap = clusterState.getCollectionsMap();

  // 1) Exact, case-sensitive match.
  DocCollection exact = collectionsMap.get(collectionName);
  if (exact != null) {
    return exact.getActiveSlices();
  }

  // 2) Case-insensitive match; iterate entries to avoid a second map lookup.
  for (Map.Entry<String, DocCollection> entry : collectionsMap.entrySet()) {
    if (entry.getKey().equalsIgnoreCase(collectionName)) {
      return entry.getValue().getActiveSlices();
    }
  }

  // 3) Collection alias, which may expand to several collections.
  if (checkAlias) {
    Aliases aliases = zkStateReader.getAliases();
    String alias = aliases.getCollectionAlias(collectionName);
    if (alias != null) {
      Collection<Slice> slices = new ArrayList<>();
      for (String aliasCollectionName : StrUtils.splitSmart(alias, ",", true)) {
        DocCollection aliasCollection = collectionsMap.get(aliasCollectionName);
        // Fix: the original dereferenced the map result unchecked and threw
        // an NPE when an alias referenced a collection that no longer exists.
        if (aliasCollection == null) {
          throw new IOException("Slices not found for " + aliasCollectionName);
        }
        slices.addAll(aliasCollection.getActiveSlices());
      }
      return slices;
    }
  }

  throw new IOException("Slices not found for " + collectionName);
}
}

View File

@ -335,7 +335,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
tuples = getTuples(stream);
assert(tuples.size() == 4);
assertOrder(tuples, 4,3,1,2);
assertOrder(tuples, 4, 3, 1, 2);
// Basic w/multi comp
expression = StreamExpressionParser.parse("unique(search(" + COLLECTIONORALIAS + ", q=*:*, fl=\"id,a_s,a_i,a_f\", sort=\"a_f asc, a_i asc\"), over=\"a_f, a_i\")");
@ -1577,7 +1577,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
List<Tuple> tuples = getTuples(pstream);
assert(tuples.size() == 5);
assertOrder(tuples, 0,1,3,4,6);
assertOrder(tuples, 0, 1, 3, 4, 6);
//Test the eofTuples
@ -4712,8 +4712,6 @@ public class StreamExpressionTest extends SolrCloudTestCase {
@Test
public void testSignificantTermsStream() throws Exception {
Assume.assumeTrue(!useAlias);
UpdateRequest updateRequest = new UpdateRequest();
for (int i = 0; i < 5000; i++) {
updateRequest.add(id, "a"+i, "test_t", "a b c d m l");
@ -4742,106 +4740,186 @@ public class StreamExpressionTest extends SolrCloudTestCase {
StreamFactory factory = new StreamFactory()
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
.withDefaultZkHost(cluster.getZkServer().getZkAddress())
.withFunctionName("significantTerms", SignificantTermsStream.class);
String significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
StreamContext streamContext = new StreamContext();
SolrClientCache cache = new SolrClientCache();
streamContext.setSolrClientCache(cache);
try {
assert(tuples.size() == 3);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
String significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert (tuples.size() == 3);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("m"));
assertTrue(tuples.get(1).getLong("background") == 5500);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("m"));
assertTrue(tuples.get(1).getLong("background") == 5500);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
assertTrue(tuples.get(2).get("term").equals("d"));
assertTrue(tuples.get(2).getLong("background") == 5600);
assertTrue(tuples.get(2).getLong("foreground") == 5000);
assertTrue(tuples.get(2).get("term").equals("d"));
assertTrue(tuples.get(2).getLong("background") == 5600);
assertTrue(tuples.get(2).getLong("foreground") == 5000);
//Test maxDocFreq
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=2650, minTermLength=1)";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
//Test maxDocFreq
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=2650, minTermLength=1)";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert(tuples.size() == 1);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assert (tuples.size() == 1);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
//Test maxDocFreq percentage
//Test maxDocFreq percentage
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=\".45\", minTermLength=1)";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
assert(tuples.size() == 1);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, maxDocFreq=\".45\", minTermLength=1)";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert (tuples.size() == 1);
assertTrue(tuples.get(0).get("term").equals("l"));
assertTrue(tuples.get(0).getLong("background") == 5000);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
//Test min doc freq
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
//Test min doc freq
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert(tuples.size() == 3);
assert (tuples.size() == 3);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
assertTrue(tuples.get(2).get("term").equals("c"));
assertTrue(tuples.get(2).getLong("background") == 5900);
assertTrue(tuples.get(2).getLong("foreground") == 5000);
assertTrue(tuples.get(2).get("term").equals("c"));
assertTrue(tuples.get(2).getLong("background") == 5900);
assertTrue(tuples.get(2).getLong("foreground") == 5000);
//Test min doc freq percent
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\".478\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
//Test min doc freq percent
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=3, minDocFreq=\".478\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert(tuples.size() == 1);
assert (tuples.size() == 1);
assertTrue(tuples.get(0).get("term").equals("c"));
assertTrue(tuples.get(0).getLong("background") == 5900);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(0).get("term").equals("c"));
assertTrue(tuples.get(0).getLong("background") == 5900);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
//Test limit
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert (tuples.size() == 2);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
//Test term length
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=2)";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(streamContext);
tuples = getTuples(stream);
assert (tuples.size() == 0);
//Test limit
//Test with shards parameter
List<String> shardUrls = TupleStream.getShards(cluster.getZkServer().getZkAddress(), COLLECTIONORALIAS, streamContext);
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
Map<String, List<String>> shardsMap = new HashMap();
shardsMap.put("myCollection", shardUrls);
StreamContext context = new StreamContext();
context.put("shards", shardsMap);
context.setSolrClientCache(cache);
significantTerms = "significantTerms(myCollection, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=1, maxDocFreq=\".5\")";
stream = factory.constructStream(significantTerms);
stream.setStreamContext(context);
tuples = getTuples(stream);
assert(tuples.size() == 2);
assert (tuples.size() == 2);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
//Exercise the /stream handler
//Add the shards http parameter for the myCollection
StringBuilder buf = new StringBuilder();
for (String shardUrl : shardUrls) {
if (buf.length() > 0) {
buf.append(",");
}
buf.append(shardUrl);
}
ModifiableSolrParams solrParams = new ModifiableSolrParams();
solrParams.add("qt", "/stream");
solrParams.add("expr", significantTerms);
solrParams.add("myCollection.shards", buf.toString());
SolrStream solrStream = new SolrStream(shardUrls.get(0), solrParams);
tuples = getTuples(solrStream);
assert (tuples.size() == 2);
assertTrue(tuples.get(0).get("term").equals("m"));
assertTrue(tuples.get(0).getLong("background") == 5500);
assertTrue(tuples.get(0).getLong("foreground") == 5000);
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
//Add a negative test to prove that it cannot find slices if shards parameter is removed
try {
ModifiableSolrParams solrParamsBad = new ModifiableSolrParams();
solrParamsBad.add("qt", "/stream");
solrParamsBad.add("expr", significantTerms);
solrStream = new SolrStream(shardUrls.get(0), solrParamsBad);
tuples = getTuples(solrStream);
throw new Exception("Exception should have been thrown above");
} catch (IOException e) {
assertTrue(e.getMessage().contains("Slices not found for myCollection"));
}
} finally {
cache.close();
}
assertTrue(tuples.get(1).get("term").equals("d"));
assertTrue(tuples.get(1).getLong("background") == 5600);
assertTrue(tuples.get(1).getLong("foreground") == 5000);
//Test term length
significantTerms = "significantTerms(collection1, q=\"id:a*\", field=\"test_t\", limit=2, minDocFreq=\"2700\", minTermLength=2)";
stream = factory.constructStream(significantTerms);
tuples = getTuples(stream);
assert(tuples.size() == 0);
}
@Test
public void testComplementStream() throws Exception {
@ -4920,12 +4998,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
protected List<Tuple> getTuples(TupleStream tupleStream) throws IOException {
tupleStream.open();
List<Tuple> tuples = new ArrayList<Tuple>();
for(Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
tuples.add(t);
try {
tupleStream.open();
for (Tuple t = tupleStream.read(); !t.EOF; t = tupleStream.read()) {
tuples.add(t);
}
} finally {
tupleStream.close();
}
tupleStream.close();
return tuples;
}
protected boolean assertOrder(List<Tuple> tuples, int... ids) throws Exception {