SOLR-9537: Support facet scoring with the scoreNodes expression

This commit is contained in:
Joel Bernstein 2016-09-23 14:05:36 -04:00
parent b8dd3be93a
commit 96af65257d
3 changed files with 107 additions and 1 deletions

View File

@ -201,6 +201,14 @@ public class FacetStream extends TupleStream implements Expressible {
init(collectionName, params, buckets, bucketSorts, metrics, limitInt, zkHost);
}
public Bucket[] getBuckets() {
return this.buckets;
}
public String getCollection() {
return this.collection;
}
private FieldComparator[] parseBucketSorts(String bucketSortString) throws IOException {
String[] sorts = bucketSortString.split(",");

View File

@ -65,6 +65,10 @@ public class ScoreNodesStream extends TupleStream implements Expressible
private Map<String, Tuple> nodes = new HashMap();
private Iterator<Tuple> tuples;
private String termFreq;
private boolean facet;
private String bucket;
private String facetCollection;
public ScoreNodesStream(TupleStream tupleStream, String nodeFreqField) throws IOException {
init(tupleStream, nodeFreqField);
@ -98,6 +102,17 @@ public class ScoreNodesStream extends TupleStream implements Expressible
private void init(TupleStream tupleStream, String termFreq) throws IOException{
this.stream = tupleStream;
this.termFreq = termFreq;
if(stream instanceof FacetStream) {
FacetStream facetStream = (FacetStream) stream;
if(facetStream.getBuckets().length != 1) {
throw new IOException("scoreNodes operates over a single bucket. Num buckets:"+facetStream.getBuckets().length);
}
this.bucket = facetStream.getBuckets()[0].toString();
this.facetCollection = facetStream.getCollection();
this.facet = true;
}
}
@Override
@ -164,13 +179,21 @@ public class ScoreNodesStream extends TupleStream implements Expressible
break;
}
if(facet) {
//Turn the facet tuple into a node.
String nodeId = node.getString(bucket);
node.put("node", nodeId);
node.remove(bucket);
node.put("collection", facetCollection);
node.put("field", bucket);
}
if(!node.fields.containsKey("node")) {
throw new IOException("node field not present in the Tuple");
}
String nodeId = node.getString("node");
nodes.put(nodeId, node);
if(builder.length() > 0) {
builder.append(",");

View File

@ -39,6 +39,7 @@ import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
import org.apache.solr.client.solrj.io.comp.FieldComparator;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.FacetStream;
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
import org.apache.solr.client.solrj.io.stream.SortStream;
@ -510,6 +511,80 @@ public class GraphExpressionTest extends SolrCloudTestCase {
}
@Test
public void testScoreNodesFacetStream() throws Exception {
new UpdateRequest()
.add(id, "0", "basket_s", "basket1", "product_ss", "product1", "product_ss", "product3", "product_ss", "product5", "price_f", "1")
.add(id, "3", "basket_s", "basket2", "product_ss", "product1", "product_ss", "product6", "product_ss", "product7", "price_f", "1")
.add(id, "6", "basket_s", "basket3", "product_ss", "product4", "product_ss","product3", "product_ss","product1", "price_f", "1")
.add(id, "9", "basket_s", "basket4", "product_ss", "product4", "product_ss", "product3", "product_ss", "product1","price_f", "1")
.add(id, "12", "basket_s", "basket5", "product_ss", "product1", "price_f", "1")
.add(id, "13", "basket_s", "basket6", "product_ss", "product1", "price_f", "1")
.add(id, "14", "basket_s", "basket7", "product_ss", "product1", "price_f", "1")
.add(id, "15", "basket_s", "basket4", "product_ss", "product1", "price_f", "1")
.commit(cluster.getSolrClient(), COLLECTION);
List<Tuple> tuples = null;
TupleStream stream = null;
StreamContext context = new StreamContext();
SolrClientCache cache = new SolrClientCache();
context.setSolrClientCache(cache);
StreamFactory factory = new StreamFactory()
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
.withDefaultZkHost(cluster.getZkServer().getZkAddress())
.withFunctionName("gatherNodes", GatherNodesStream.class)
.withFunctionName("scoreNodes", ScoreNodesStream.class)
.withFunctionName("search", CloudSolrStream.class)
.withFunctionName("facet", FacetStream.class)
.withFunctionName("sort", SortStream.class)
.withFunctionName("count", CountMetric.class)
.withFunctionName("avg", MeanMetric.class)
.withFunctionName("sum", SumMetric.class)
.withFunctionName("min", MinMetric.class)
.withFunctionName("max", MaxMetric.class);
String expr = "sort(by=\"nodeScore desc\",scoreNodes(facet(collection1, q=\"product_ss:product3\", buckets=\"product_ss\", bucketSorts=\"count(*) desc\", bucketSizeLimit=100, count(*))))";
stream = factory.constructStream(expr);
context = new StreamContext();
context.setSolrClientCache(cache);
stream.setStreamContext(context);
tuples = getTuples(stream);
//The highest scoring tuple will be the product searched for.
Tuple tuple = tuples.get(0);
assert(tuple.getString("node").equals("product3"));
assert(tuple.getLong("docFreq") == 3);
assert(tuple.getLong("count(*)") == 3);
Tuple tuple0 = tuples.get(1);
assert(tuple0.getString("node").equals("product4"));
assert(tuple0.getLong("docFreq") == 2);
assert(tuple0.getLong("count(*)") == 2);
Tuple tuple1 = tuples.get(2);
assert(tuple1.getString("node").equals("product1"));
assert(tuple1.getLong("docFreq") == 8);
assert(tuple1.getLong("count(*)") == 3);
Tuple tuple2 = tuples.get(3);
assert(tuple2.getString("node").equals("product5"));
assert(tuple2.getLong("docFreq") == 1);
assert(tuple2.getLong("count(*)") == 1);
cache.close();
}
@Test
public void testGatherNodesFriendsStream() throws Exception {