From 96af65257dc250660e2810a0a8bcf86364921a07 Mon Sep 17 00:00:00 2001 From: Joel Bernstein Date: Fri, 23 Sep 2016 14:05:36 -0400 Subject: [PATCH] SOLR-9537: Support facet scoring with the scoreNodes expression --- .../client/solrj/io/stream/FacetStream.java | 8 ++ .../solrj/io/stream/ScoreNodesStream.java | 25 ++++++- .../solrj/io/graph/GraphExpressionTest.java | 75 +++++++++++++++++++ 3 files changed, 107 insertions(+), 1 deletion(-) diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java index 6802d0e62d2..4e239e62cef 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/FacetStream.java @@ -201,6 +201,14 @@ public class FacetStream extends TupleStream implements Expressible { init(collectionName, params, buckets, bucketSorts, metrics, limitInt, zkHost); } + public Bucket[] getBuckets() { + return this.buckets; + } + + public String getCollection() { + return this.collection; + } + private FieldComparator[] parseBucketSorts(String bucketSortString) throws IOException { String[] sorts = bucketSortString.split(","); diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java index 6a0cfc7abfd..41e719779c0 100644 --- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java +++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ScoreNodesStream.java @@ -65,6 +65,10 @@ public class ScoreNodesStream extends TupleStream implements Expressible private Map nodes = new HashMap(); private Iterator tuples; private String termFreq; + private boolean facet; + + private String bucket; + private String facetCollection; public ScoreNodesStream(TupleStream tupleStream, String nodeFreqField) throws IOException { init(tupleStream, nodeFreqField); @@ -98,6 +102,17 @@ public class ScoreNodesStream extends TupleStream implements Expressible private void init(TupleStream tupleStream, String termFreq) throws IOException{ this.stream = tupleStream; this.termFreq = termFreq; + if(stream instanceof FacetStream) { + FacetStream facetStream = (FacetStream) stream; + + if(facetStream.getBuckets().length != 1) { + throw new IOException("scoreNodes operates over a single bucket. Num buckets:"+facetStream.getBuckets().length); + } + + this.bucket = facetStream.getBuckets()[0].toString(); + this.facetCollection = facetStream.getCollection(); + this.facet = true; + } } @Override @@ -164,13 +179,21 @@ public class ScoreNodesStream extends TupleStream implements Expressible break; } + if(facet) { + //Turn the facet tuple into a node. + String nodeId = node.getString(bucket); + node.put("node", nodeId); + node.remove(bucket); + node.put("collection", facetCollection); + node.put("field", bucket); + } + if(!node.fields.containsKey("node")) { throw new IOException("node field not present in the Tuple"); } String nodeId = node.getString("node"); - nodes.put(nodeId, node); if(builder.length() > 0) { builder.append(","); diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java index dcd5ff40fbd..d6fc5143a64 100644 --- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java +++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/graph/GraphExpressionTest.java @@ -39,6 +39,7 @@ import org.apache.solr.client.solrj.io.Tuple; import org.apache.solr.client.solrj.io.comp.ComparatorOrder; import org.apache.solr.client.solrj.io.comp.FieldComparator; import org.apache.solr.client.solrj.io.stream.CloudSolrStream; +import org.apache.solr.client.solrj.io.stream.FacetStream; import org.apache.solr.client.solrj.io.stream.HashJoinStream; import org.apache.solr.client.solrj.io.stream.ScoreNodesStream; import org.apache.solr.client.solrj.io.stream.SortStream; @@ -510,6 +511,80 @@ public class GraphExpressionTest extends SolrCloudTestCase { } + @Test + public void testScoreNodesFacetStream() throws Exception { + + + new UpdateRequest() + .add(id, "0", "basket_s", "basket1", "product_ss", "product1", "product_ss", "product3", "product_ss", "product5", "price_f", "1") + .add(id, "3", "basket_s", "basket2", "product_ss", "product1", "product_ss", "product6", "product_ss", "product7", "price_f", "1") + .add(id, "6", "basket_s", "basket3", "product_ss", "product4", "product_ss","product3", "product_ss","product1", "price_f", "1") + .add(id, "9", "basket_s", "basket4", "product_ss", "product4", "product_ss", "product3", "product_ss", "product1","price_f", "1") + .add(id, "12", "basket_s", "basket5", "product_ss", "product1", "price_f", "1") + .add(id, "13", "basket_s", "basket6", "product_ss", "product1", "price_f", "1") + .add(id, "14", "basket_s", "basket7", "product_ss", "product1", "price_f", "1") + .add(id, "15", "basket_s", "basket4", "product_ss", "product1", "price_f", "1") + .commit(cluster.getSolrClient(), COLLECTION); + + List tuples = null; + TupleStream stream = null; + StreamContext context = new StreamContext(); + SolrClientCache cache = new SolrClientCache(); + context.setSolrClientCache(cache); + + StreamFactory factory = new StreamFactory() + .withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress()) + .withDefaultZkHost(cluster.getZkServer().getZkAddress()) + .withFunctionName("gatherNodes", GatherNodesStream.class) + .withFunctionName("scoreNodes", ScoreNodesStream.class) + .withFunctionName("search", CloudSolrStream.class) + .withFunctionName("facet", FacetStream.class) + .withFunctionName("sort", SortStream.class) + .withFunctionName("count", CountMetric.class) + .withFunctionName("avg", MeanMetric.class) + .withFunctionName("sum", SumMetric.class) + .withFunctionName("min", MinMetric.class) + .withFunctionName("max", MaxMetric.class); + + String expr = "sort(by=\"nodeScore desc\",scoreNodes(facet(collection1, q=\"product_ss:product3\", buckets=\"product_ss\", bucketSorts=\"count(*) desc\", bucketSizeLimit=100, count(*))))"; + + stream = factory.constructStream(expr); + + context = new StreamContext(); + context.setSolrClientCache(cache); + + stream.setStreamContext(context); + tuples = getTuples(stream); + + //The highest scoring tuple will be the product searched for. + Tuple tuple = tuples.get(0); + assert(tuple.getString("node").equals("product3")); + assert(tuple.getLong("docFreq") == 3); + assert(tuple.getLong("count(*)") == 3); + + Tuple tuple0 = tuples.get(1); + assert(tuple0.getString("node").equals("product4")); + assert(tuple0.getLong("docFreq") == 2); + assert(tuple0.getLong("count(*)") == 2); + + Tuple tuple1 = tuples.get(2); + assert(tuple1.getString("node").equals("product1")); + assert(tuple1.getLong("docFreq") == 8); + assert(tuple1.getLong("count(*)") == 3); + + Tuple tuple2 = tuples.get(3); + assert(tuple2.getString("node").equals("product5")); + assert(tuple2.getLong("docFreq") == 1); + assert(tuple2.getLong("count(*)") == 1); + + + cache.close(); + } + + + + + @Test public void testGatherNodesFriendsStream() throws Exception {