mirror of
https://github.com/apache/lucene.git
synced 2025-02-12 21:15:19 +00:00
SOLR-9537: Support facet scoring with the scoreNodes expression
This commit is contained in:
parent
b8dd3be93a
commit
96af65257d
@ -201,6 +201,14 @@ public class FacetStream extends TupleStream implements Expressible {
|
||||
init(collectionName, params, buckets, bucketSorts, metrics, limitInt, zkHost);
|
||||
}
|
||||
|
||||
public Bucket[] getBuckets() {
|
||||
return this.buckets;
|
||||
}
|
||||
|
||||
public String getCollection() {
|
||||
return this.collection;
|
||||
}
|
||||
|
||||
private FieldComparator[] parseBucketSorts(String bucketSortString) throws IOException {
|
||||
|
||||
String[] sorts = bucketSortString.split(",");
|
||||
|
@ -65,6 +65,10 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
||||
private Map<String, Tuple> nodes = new HashMap();
|
||||
private Iterator<Tuple> tuples;
|
||||
private String termFreq;
|
||||
private boolean facet;
|
||||
|
||||
private String bucket;
|
||||
private String facetCollection;
|
||||
|
||||
public ScoreNodesStream(TupleStream tupleStream, String nodeFreqField) throws IOException {
|
||||
init(tupleStream, nodeFreqField);
|
||||
@ -98,6 +102,17 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
||||
private void init(TupleStream tupleStream, String termFreq) throws IOException{
|
||||
this.stream = tupleStream;
|
||||
this.termFreq = termFreq;
|
||||
if(stream instanceof FacetStream) {
|
||||
FacetStream facetStream = (FacetStream) stream;
|
||||
|
||||
if(facetStream.getBuckets().length != 1) {
|
||||
throw new IOException("scoreNodes operates over a single bucket. Num buckets:"+facetStream.getBuckets().length);
|
||||
}
|
||||
|
||||
this.bucket = facetStream.getBuckets()[0].toString();
|
||||
this.facetCollection = facetStream.getCollection();
|
||||
this.facet = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -164,13 +179,21 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
||||
break;
|
||||
}
|
||||
|
||||
if(facet) {
|
||||
//Turn the facet tuple into a node.
|
||||
String nodeId = node.getString(bucket);
|
||||
node.put("node", nodeId);
|
||||
node.remove(bucket);
|
||||
node.put("collection", facetCollection);
|
||||
node.put("field", bucket);
|
||||
}
|
||||
|
||||
if(!node.fields.containsKey("node")) {
|
||||
throw new IOException("node field not present in the Tuple");
|
||||
}
|
||||
|
||||
String nodeId = node.getString("node");
|
||||
|
||||
|
||||
nodes.put(nodeId, node);
|
||||
if(builder.length() > 0) {
|
||||
builder.append(",");
|
||||
|
@ -39,6 +39,7 @@ import org.apache.solr.client.solrj.io.Tuple;
|
||||
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
|
||||
import org.apache.solr.client.solrj.io.comp.FieldComparator;
|
||||
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
|
||||
import org.apache.solr.client.solrj.io.stream.FacetStream;
|
||||
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
|
||||
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
|
||||
import org.apache.solr.client.solrj.io.stream.SortStream;
|
||||
@ -510,6 +511,80 @@ public class GraphExpressionTest extends SolrCloudTestCase {
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testScoreNodesFacetStream() throws Exception {
|
||||
|
||||
|
||||
new UpdateRequest()
|
||||
.add(id, "0", "basket_s", "basket1", "product_ss", "product1", "product_ss", "product3", "product_ss", "product5", "price_f", "1")
|
||||
.add(id, "3", "basket_s", "basket2", "product_ss", "product1", "product_ss", "product6", "product_ss", "product7", "price_f", "1")
|
||||
.add(id, "6", "basket_s", "basket3", "product_ss", "product4", "product_ss","product3", "product_ss","product1", "price_f", "1")
|
||||
.add(id, "9", "basket_s", "basket4", "product_ss", "product4", "product_ss", "product3", "product_ss", "product1","price_f", "1")
|
||||
.add(id, "12", "basket_s", "basket5", "product_ss", "product1", "price_f", "1")
|
||||
.add(id, "13", "basket_s", "basket6", "product_ss", "product1", "price_f", "1")
|
||||
.add(id, "14", "basket_s", "basket7", "product_ss", "product1", "price_f", "1")
|
||||
.add(id, "15", "basket_s", "basket4", "product_ss", "product1", "price_f", "1")
|
||||
.commit(cluster.getSolrClient(), COLLECTION);
|
||||
|
||||
List<Tuple> tuples = null;
|
||||
TupleStream stream = null;
|
||||
StreamContext context = new StreamContext();
|
||||
SolrClientCache cache = new SolrClientCache();
|
||||
context.setSolrClientCache(cache);
|
||||
|
||||
StreamFactory factory = new StreamFactory()
|
||||
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
|
||||
.withDefaultZkHost(cluster.getZkServer().getZkAddress())
|
||||
.withFunctionName("gatherNodes", GatherNodesStream.class)
|
||||
.withFunctionName("scoreNodes", ScoreNodesStream.class)
|
||||
.withFunctionName("search", CloudSolrStream.class)
|
||||
.withFunctionName("facet", FacetStream.class)
|
||||
.withFunctionName("sort", SortStream.class)
|
||||
.withFunctionName("count", CountMetric.class)
|
||||
.withFunctionName("avg", MeanMetric.class)
|
||||
.withFunctionName("sum", SumMetric.class)
|
||||
.withFunctionName("min", MinMetric.class)
|
||||
.withFunctionName("max", MaxMetric.class);
|
||||
|
||||
String expr = "sort(by=\"nodeScore desc\",scoreNodes(facet(collection1, q=\"product_ss:product3\", buckets=\"product_ss\", bucketSorts=\"count(*) desc\", bucketSizeLimit=100, count(*))))";
|
||||
|
||||
stream = factory.constructStream(expr);
|
||||
|
||||
context = new StreamContext();
|
||||
context.setSolrClientCache(cache);
|
||||
|
||||
stream.setStreamContext(context);
|
||||
tuples = getTuples(stream);
|
||||
|
||||
//The highest scoring tuple will be the product searched for.
|
||||
Tuple tuple = tuples.get(0);
|
||||
assert(tuple.getString("node").equals("product3"));
|
||||
assert(tuple.getLong("docFreq") == 3);
|
||||
assert(tuple.getLong("count(*)") == 3);
|
||||
|
||||
Tuple tuple0 = tuples.get(1);
|
||||
assert(tuple0.getString("node").equals("product4"));
|
||||
assert(tuple0.getLong("docFreq") == 2);
|
||||
assert(tuple0.getLong("count(*)") == 2);
|
||||
|
||||
Tuple tuple1 = tuples.get(2);
|
||||
assert(tuple1.getString("node").equals("product1"));
|
||||
assert(tuple1.getLong("docFreq") == 8);
|
||||
assert(tuple1.getLong("count(*)") == 3);
|
||||
|
||||
Tuple tuple2 = tuples.get(3);
|
||||
assert(tuple2.getString("node").equals("product5"));
|
||||
assert(tuple2.getLong("docFreq") == 1);
|
||||
assert(tuple2.getLong("count(*)") == 1);
|
||||
|
||||
|
||||
cache.close();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testGatherNodesFriendsStream() throws Exception {
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user