diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java index 0d305fdbff8..814b69c0a9f 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java @@ -165,7 +165,13 @@ public class ScoreNodesStream extends TupleStream implements Expressible break; } + if(!node.fields.containsKey("node")) { + throw new IOException("node field not present in the Tuple"); + } + String nodeId = node.getString("node"); + + nodes.put(nodeId, node); if(builder.length() > 0) { builder.append(","); @@ -202,6 +208,9 @@ public class ScoreNodesStream extends TupleStream implements Expressible String term = terms.getName(t); Number docFreq = terms.get(term); Tuple tuple = nodes.get(term); + if(!tuple.fields.containsKey(termFreq)) { + throw new Exception("termFreq field not present in the Tuple"); + } Number termFreqValue = (Number)tuple.get(termFreq); float score = termFreqValue.floatValue() * (float) (Math.log((numDocs + 1) / (docFreq.doubleValue() + 1)) + 1.0); tuple.put("nodeScore", score); diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java index 9dbd70631c5..a141b7341ab 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java @@ -392,26 +392,25 @@ public class GraphExpressionTest extends SolrCloudTestCase { new UpdateRequest() - .add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "20") - .add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "30") - .add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "1") - .add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "2") - .add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "5") - .add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "10") - .add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "20") - .add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "10") - .add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "10") - .add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "40") - .add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "10") - .add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "10") - .add(id, "12", "basket_s", "basket5", "product_s", "product1", "price_f", "10") - .add(id, "13", "basket_s", "basket6", "product_s", "product1", "price_f", "10") - .add(id, "14", "basket_s", "basket7", "product_s", "product1", "price_f", "10") - .add(id, "15", "basket_s", "basket4", "product_s", "product1", "price_f", "10") + .add(id, "0", "basket_s", "basket1", "product_s", "product1", "price_f", "1") + .add(id, "1", "basket_s", "basket1", "product_s", "product3", "price_f", "1") + .add(id, "2", "basket_s", "basket1", "product_s", "product5", "price_f", "100") + .add(id, "3", "basket_s", "basket2", "product_s", "product1", "price_f", "1") + .add(id, "4", "basket_s", "basket2", "product_s", "product6", "price_f", "1") + .add(id, "5", "basket_s", "basket2", "product_s", "product7", "price_f", "1") + .add(id, "6", "basket_s", "basket3", "product_s", "product4", "price_f", "1") + .add(id, "7", "basket_s", "basket3", "product_s", "product3", "price_f", "1") + .add(id, "8", "basket_s", "basket3", "product_s", "product1", "price_f", "1") + .add(id, "9", "basket_s", "basket4", "product_s", "product4", "price_f", "1") + .add(id, "10", "basket_s", "basket4", "product_s", "product3", "price_f", "1") + .add(id, "11", "basket_s", "basket4", "product_s", "product1", "price_f", "1") + .add(id, "12", "basket_s", "basket5", "product_s", "product1", "price_f", "1") + .add(id, "13", "basket_s", "basket6", "product_s", "product1", "price_f", "1") + .add(id, "14", "basket_s", "basket7", "product_s", "product1", "price_f", "1") + .add(id, "15", "basket_s", "basket4", "product_s", "product1", "price_f", "1") .commit(cluster.getSolrClient(), COLLECTION); List tuples = null; - Set paths = null; TupleStream stream = null; StreamContext context = new StreamContext(); SolrClientCache cache = new SolrClientCache(); @@ -470,6 +469,43 @@ public class GraphExpressionTest extends SolrCloudTestCase { assert(tuple2.getLong("docFreq") == 1); assert(tuple2.getLong("count(*)") == 1); + + //Test using a different termFreq field then the default count(*) + expr2 = "sort(by=\"nodeScore desc\", " + + "scoreNodes(termFreq=\"avg(price_f)\",gatherNodes(collection1, " + + expr+","+ + "walk=\"node->basket_s\"," + + "gather=\"product_s\", " + + "count(*), " + + "avg(price_f), " + + "sum(price_f), " + + "min(price_f), " + + "max(price_f))))"; + + stream = factory.constructStream(expr2); + + context = new StreamContext(); + context.setSolrClientCache(cache); + + stream.setStreamContext(context); + + tuples = getTuples(stream); + + tuple0 = tuples.get(0); + assert(tuple0.getString("node").equals("product5")); + assert(tuple0.getLong("docFreq") == 1); + assert(tuple0.getDouble("avg(price_f)") == 100); + + tuple1 = tuples.get(1); + assert(tuple1.getString("node").equals("product4")); + assert(tuple1.getLong("docFreq") == 2); + assert(tuple1.getDouble("avg(price_f)") == 1); + + tuple2 = tuples.get(2); + assert(tuple2.getString("node").equals("product1")); + assert(tuple2.getLong("docFreq") == 8); + assert(tuple2.getDouble("avg(price_f)") == 1); + cache.close(); }