diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java index f6b5818be72..cb46db4d8bc 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/StatsStream.java @@ -26,7 +26,7 @@ import java.util.Map.Entry; import java.util.stream.Collectors; import org.apache.solr.client.solrj.impl.CloudSolrClient; -import org.apache.solr.client.solrj.impl.CloudSolrClient.Builder; +import org.apache.solr.client.solrj.impl.HttpSolrClient; import org.apache.solr.client.solrj.io.SolrClientCache; import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.StreamComparator; @@ -61,6 +61,7 @@ public class StatsStream extends TupleStream implements Expressible { private Map metricMap; protected transient SolrClientCache cache; protected transient CloudSolrClient cloudSolrClient; + protected StreamContext streamContext; // Use StatsStream(String, String, SolrParams, Metric[] @Deprecated @@ -129,9 +130,12 @@ public class StatsStream extends TupleStream implements Expressible { else if(zkHostExpression.getParameter() instanceof StreamExpressionValue){ zkHost = ((StreamExpressionValue)zkHostExpression.getParameter()).getValue(); } + + /* if(null == zkHost){ throw new IOException(String.format(Locale.ROOT,"invalid expression %s - zkHost not found for collection '%s'",expression,collectionName)); } + */ // metrics, optional - if not provided then why are you using this? Metric[] metrics = new Metric[metricExpressions.size()]; @@ -195,6 +199,7 @@ public class StatsStream extends TupleStream implements Expressible { } public void setStreamContext(StreamContext context) { + streamContext = context; cache = context.getSolrClientCache(); } @@ -203,32 +208,56 @@ public class StatsStream extends TupleStream implements Expressible { } public void open() throws IOException { - if(cache != null) { - cloudSolrClient = cache.getCloudSolrClient(zkHost); - } else { - cloudSolrClient = new Builder() - .withZkHost(zkHost) - .build(); - } - ModifiableSolrParams paramsLoc = new ModifiableSolrParams(this.params); addStats(paramsLoc, metrics); paramsLoc.set("stats", "true"); paramsLoc.set("rows", "0"); - QueryRequest request = new QueryRequest(paramsLoc); - try { - NamedList response = cloudSolrClient.request(request, collection); - this.tuple = getTuple(response); - } catch (Exception e) { - throw new IOException(e); + Map> shardsMap = (Map>)streamContext.get("shards"); + if(shardsMap == null) { + QueryRequest request = new QueryRequest(paramsLoc); + CloudSolrClient cloudSolrClient = cache.getCloudSolrClient(zkHost); + try { + NamedList response = cloudSolrClient.request(request, collection); + this.tuple = getTuple(response); + } catch (Exception e) { + throw new IOException(e); + } + } else { + List shards = shardsMap.get(collection); + HttpSolrClient client = cache.getHttpSolrClient(shards.get(0)); + + if(shards.size() > 1) { + String shardsParam = getShardString(shards); + paramsLoc.add("shards", shardsParam); + paramsLoc.add("distrib", "true"); + } + + QueryRequest request = new QueryRequest(paramsLoc); + try { + NamedList response = client.request(request); + this.tuple = getTuple(response); + } catch (Exception e) { + throw new IOException(e); + } } } - public void close() throws IOException { - if(cache == null) { - cloudSolrClient.close(); + private String getShardString(List shards) { + StringBuilder builder = new StringBuilder(); + for(String shard : shards) { + if(builder.length() > 0) { + builder.append(","); + } + builder.append(shard); } + return builder.toString(); + } + + + + public void close() throws IOException { + } public Tuple read() throws IOException { diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java index 006c5896a1c..9ea8ca40909 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java @@ -1695,38 +1695,138 @@ public class StreamExpressionTest extends SolrCloudTestCase { StreamExpression expression; TupleStream stream; List tuples; - - expression = StreamExpressionParser.parse("stats(collection1, q=*:*, sum(a_i), sum(a_f), min(a_i), min(a_f), max(a_i), max(a_f), avg(a_i), avg(a_f), count(*))"); - stream = factory.constructStream(expression); + StreamContext streamContext = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + try { + streamContext.setSolrClientCache(cache); + String expr = "stats(" + COLLECTIONORALIAS + ", q=*:*, sum(a_i), sum(a_f), min(a_i), min(a_f), max(a_i), max(a_f), avg(a_i), avg(a_f), count(*))"; + expression = StreamExpressionParser.parse(expr); + stream = factory.constructStream(expression); + stream.setStreamContext(streamContext); - tuples = getTuples(stream); + tuples = getTuples(stream); - assert(tuples.size() == 1); + assert (tuples.size() == 1); - //Test Long and Double Sums + //Test Long and Double Sums - Tuple tuple = tuples.get(0); + Tuple tuple = tuples.get(0); - Double sumi = tuple.getDouble("sum(a_i)"); - Double sumf = tuple.getDouble("sum(a_f)"); - Double mini = tuple.getDouble("min(a_i)"); - Double minf = tuple.getDouble("min(a_f)"); - Double maxi = tuple.getDouble("max(a_i)"); - Double maxf = tuple.getDouble("max(a_f)"); - Double avgi = tuple.getDouble("avg(a_i)"); - Double avgf = tuple.getDouble("avg(a_f)"); - Double count = tuple.getDouble("count(*)"); + Double sumi = tuple.getDouble("sum(a_i)"); + Double sumf = tuple.getDouble("sum(a_f)"); + Double mini = tuple.getDouble("min(a_i)"); + Double minf = tuple.getDouble("min(a_f)"); + Double maxi = tuple.getDouble("max(a_i)"); + Double maxf = tuple.getDouble("max(a_f)"); + Double avgi = tuple.getDouble("avg(a_i)"); + Double avgf = tuple.getDouble("avg(a_f)"); + Double count = tuple.getDouble("count(*)"); - assertTrue(sumi.longValue() == 70); - assertTrue(sumf.doubleValue() == 55.0D); - assertTrue(mini.doubleValue() == 0.0D); - assertTrue(minf.doubleValue() == 1.0D); - assertTrue(maxi.doubleValue() == 14.0D); - assertTrue(maxf.doubleValue() == 10.0D); - assertTrue(avgi.doubleValue() == 7.0D); - assertTrue(avgf.doubleValue() == 5.5D); - assertTrue(count.doubleValue() == 10); + assertTrue(sumi.longValue() == 70); + assertTrue(sumf.doubleValue() == 55.0D); + assertTrue(mini.doubleValue() == 0.0D); + assertTrue(minf.doubleValue() == 1.0D); + assertTrue(maxi.doubleValue() == 14.0D); + assertTrue(maxf.doubleValue() == 10.0D); + assertTrue(avgi.doubleValue() == 7.0D); + assertTrue(avgf.doubleValue() == 5.5D); + assertTrue(count.doubleValue() == 10); + + //Test with shards parameter + List shardUrls = TupleStream.getShards(cluster.getZkServer().getZkAddress(), COLLECTIONORALIAS, streamContext); + expr = "stats(myCollection, q=*:*, sum(a_i), sum(a_f), min(a_i), min(a_f), max(a_i), max(a_f), avg(a_i), avg(a_f), count(*))"; + Map> shardsMap = new HashMap(); + shardsMap.put("myCollection", shardUrls); + StreamContext context = new StreamContext(); + context.put("shards", shardsMap); + context.setSolrClientCache(cache); + stream = factory.constructStream(expr); + stream.setStreamContext(context); + + tuples = getTuples(stream); + + assert (tuples.size() == 1); + + //Test Long and Double Sums + + tuple = tuples.get(0); + + sumi = tuple.getDouble("sum(a_i)"); + sumf = tuple.getDouble("sum(a_f)"); + mini = tuple.getDouble("min(a_i)"); + minf = tuple.getDouble("min(a_f)"); + maxi = tuple.getDouble("max(a_i)"); + maxf = tuple.getDouble("max(a_f)"); + avgi = tuple.getDouble("avg(a_i)"); + avgf = tuple.getDouble("avg(a_f)"); + count = tuple.getDouble("count(*)"); + + assertTrue(sumi.longValue() == 70); + assertTrue(sumf.doubleValue() == 55.0D); + assertTrue(mini.doubleValue() == 0.0D); + assertTrue(minf.doubleValue() == 1.0D); + assertTrue(maxi.doubleValue() == 14.0D); + assertTrue(maxf.doubleValue() == 10.0D); + assertTrue(avgi.doubleValue() == 7.0D); + assertTrue(avgf.doubleValue() == 5.5D); + assertTrue(count.doubleValue() == 10); + + //Execersise the /stream hander + + //Add the shards http parameter for the myCollection + StringBuilder buf = new StringBuilder(); + for (String shardUrl : shardUrls) { + if (buf.length() > 0) { + buf.append(","); + } + buf.append(shardUrl); + } + + ModifiableSolrParams solrParams = new ModifiableSolrParams(); + solrParams.add("qt", "/stream"); + solrParams.add("expr", expr); + solrParams.add("myCollection.shards", buf.toString()); + SolrStream solrStream = new SolrStream(shardUrls.get(0), solrParams); + tuples = getTuples(solrStream); + assert (tuples.size() == 1); + + tuple =tuples.get(0); + + sumi = tuple.getDouble("sum(a_i)"); + sumf = tuple.getDouble("sum(a_f)"); + mini = tuple.getDouble("min(a_i)"); + minf = tuple.getDouble("min(a_f)"); + maxi = tuple.getDouble("max(a_i)"); + maxf = tuple.getDouble("max(a_f)"); + avgi = tuple.getDouble("avg(a_i)"); + avgf = tuple.getDouble("avg(a_f)"); + count = tuple.getDouble("count(*)"); + + assertTrue(sumi.longValue() == 70); + assertTrue(sumf.doubleValue() == 55.0D); + assertTrue(mini.doubleValue() == 0.0D); + assertTrue(minf.doubleValue() == 1.0D); + assertTrue(maxi.doubleValue() == 14.0D); + assertTrue(maxf.doubleValue() == 10.0D); + assertTrue(avgi.doubleValue() == 7.0D); + assertTrue(avgf.doubleValue() == 5.5D); + assertTrue(count.doubleValue() == 10); + //Add a negative test to prove that it cannot find slices if shards parameter is removed + + try { + ModifiableSolrParams solrParamsBad = new ModifiableSolrParams(); + solrParamsBad.add("qt", "/stream"); + solrParamsBad.add("expr", expr); + solrStream = new SolrStream(shardUrls.get(0), solrParamsBad); + tuples = getTuples(solrStream); + throw new Exception("Exception should have been thrown above"); + } catch (IOException e) { + assertTrue(e.getMessage().contains("Collection not found: myCollection")); + } + } finally { + cache.close(); + } } @Test