mirror of
https://github.com/apache/lucene.git
synced 2025-02-14 22:16:00 +00:00
SOLR-9537: Support facet scoring with the scoreNodes expression
This commit is contained in:
parent
b8dd3be93a
commit
96af65257d
@ -201,6 +201,14 @@ public class FacetStream extends TupleStream implements Expressible {
|
|||||||
init(collectionName, params, buckets, bucketSorts, metrics, limitInt, zkHost);
|
init(collectionName, params, buckets, bucketSorts, metrics, limitInt, zkHost);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Bucket[] getBuckets() {
|
||||||
|
return this.buckets;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getCollection() {
|
||||||
|
return this.collection;
|
||||||
|
}
|
||||||
|
|
||||||
private FieldComparator[] parseBucketSorts(String bucketSortString) throws IOException {
|
private FieldComparator[] parseBucketSorts(String bucketSortString) throws IOException {
|
||||||
|
|
||||||
String[] sorts = bucketSortString.split(",");
|
String[] sorts = bucketSortString.split(",");
|
||||||
|
@ -65,6 +65,10 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
|||||||
private Map<String, Tuple> nodes = new HashMap();
|
private Map<String, Tuple> nodes = new HashMap();
|
||||||
private Iterator<Tuple> tuples;
|
private Iterator<Tuple> tuples;
|
||||||
private String termFreq;
|
private String termFreq;
|
||||||
|
private boolean facet;
|
||||||
|
|
||||||
|
private String bucket;
|
||||||
|
private String facetCollection;
|
||||||
|
|
||||||
public ScoreNodesStream(TupleStream tupleStream, String nodeFreqField) throws IOException {
|
public ScoreNodesStream(TupleStream tupleStream, String nodeFreqField) throws IOException {
|
||||||
init(tupleStream, nodeFreqField);
|
init(tupleStream, nodeFreqField);
|
||||||
@ -98,6 +102,17 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
|||||||
private void init(TupleStream tupleStream, String termFreq) throws IOException{
|
private void init(TupleStream tupleStream, String termFreq) throws IOException{
|
||||||
this.stream = tupleStream;
|
this.stream = tupleStream;
|
||||||
this.termFreq = termFreq;
|
this.termFreq = termFreq;
|
||||||
|
if(stream instanceof FacetStream) {
|
||||||
|
FacetStream facetStream = (FacetStream) stream;
|
||||||
|
|
||||||
|
if(facetStream.getBuckets().length != 1) {
|
||||||
|
throw new IOException("scoreNodes operates over a single bucket. Num buckets:"+facetStream.getBuckets().length);
|
||||||
|
}
|
||||||
|
|
||||||
|
this.bucket = facetStream.getBuckets()[0].toString();
|
||||||
|
this.facetCollection = facetStream.getCollection();
|
||||||
|
this.facet = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -164,13 +179,21 @@ public class ScoreNodesStream extends TupleStream implements Expressible
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(facet) {
|
||||||
|
//Turn the facet tuple into a node.
|
||||||
|
String nodeId = node.getString(bucket);
|
||||||
|
node.put("node", nodeId);
|
||||||
|
node.remove(bucket);
|
||||||
|
node.put("collection", facetCollection);
|
||||||
|
node.put("field", bucket);
|
||||||
|
}
|
||||||
|
|
||||||
if(!node.fields.containsKey("node")) {
|
if(!node.fields.containsKey("node")) {
|
||||||
throw new IOException("node field not present in the Tuple");
|
throw new IOException("node field not present in the Tuple");
|
||||||
}
|
}
|
||||||
|
|
||||||
String nodeId = node.getString("node");
|
String nodeId = node.getString("node");
|
||||||
|
|
||||||
|
|
||||||
nodes.put(nodeId, node);
|
nodes.put(nodeId, node);
|
||||||
if(builder.length() > 0) {
|
if(builder.length() > 0) {
|
||||||
builder.append(",");
|
builder.append(",");
|
||||||
|
@ -39,6 +39,7 @@ import org.apache.solr.client.solrj.io.Tuple;
|
|||||||
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
|
import org.apache.solr.client.solrj.io.comp.ComparatorOrder;
|
||||||
import org.apache.solr.client.solrj.io.comp.FieldComparator;
|
import org.apache.solr.client.solrj.io.comp.FieldComparator;
|
||||||
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
|
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
|
||||||
|
import org.apache.solr.client.solrj.io.stream.FacetStream;
|
||||||
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
|
import org.apache.solr.client.solrj.io.stream.HashJoinStream;
|
||||||
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
|
import org.apache.solr.client.solrj.io.stream.ScoreNodesStream;
|
||||||
import org.apache.solr.client.solrj.io.stream.SortStream;
|
import org.apache.solr.client.solrj.io.stream.SortStream;
|
||||||
@ -510,6 +511,80 @@ public class GraphExpressionTest extends SolrCloudTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testScoreNodesFacetStream() throws Exception {
|
||||||
|
|
||||||
|
|
||||||
|
new UpdateRequest()
|
||||||
|
.add(id, "0", "basket_s", "basket1", "product_ss", "product1", "product_ss", "product3", "product_ss", "product5", "price_f", "1")
|
||||||
|
.add(id, "3", "basket_s", "basket2", "product_ss", "product1", "product_ss", "product6", "product_ss", "product7", "price_f", "1")
|
||||||
|
.add(id, "6", "basket_s", "basket3", "product_ss", "product4", "product_ss","product3", "product_ss","product1", "price_f", "1")
|
||||||
|
.add(id, "9", "basket_s", "basket4", "product_ss", "product4", "product_ss", "product3", "product_ss", "product1","price_f", "1")
|
||||||
|
.add(id, "12", "basket_s", "basket5", "product_ss", "product1", "price_f", "1")
|
||||||
|
.add(id, "13", "basket_s", "basket6", "product_ss", "product1", "price_f", "1")
|
||||||
|
.add(id, "14", "basket_s", "basket7", "product_ss", "product1", "price_f", "1")
|
||||||
|
.add(id, "15", "basket_s", "basket4", "product_ss", "product1", "price_f", "1")
|
||||||
|
.commit(cluster.getSolrClient(), COLLECTION);
|
||||||
|
|
||||||
|
List<Tuple> tuples = null;
|
||||||
|
TupleStream stream = null;
|
||||||
|
StreamContext context = new StreamContext();
|
||||||
|
SolrClientCache cache = new SolrClientCache();
|
||||||
|
context.setSolrClientCache(cache);
|
||||||
|
|
||||||
|
StreamFactory factory = new StreamFactory()
|
||||||
|
.withCollectionZkHost("collection1", cluster.getZkServer().getZkAddress())
|
||||||
|
.withDefaultZkHost(cluster.getZkServer().getZkAddress())
|
||||||
|
.withFunctionName("gatherNodes", GatherNodesStream.class)
|
||||||
|
.withFunctionName("scoreNodes", ScoreNodesStream.class)
|
||||||
|
.withFunctionName("search", CloudSolrStream.class)
|
||||||
|
.withFunctionName("facet", FacetStream.class)
|
||||||
|
.withFunctionName("sort", SortStream.class)
|
||||||
|
.withFunctionName("count", CountMetric.class)
|
||||||
|
.withFunctionName("avg", MeanMetric.class)
|
||||||
|
.withFunctionName("sum", SumMetric.class)
|
||||||
|
.withFunctionName("min", MinMetric.class)
|
||||||
|
.withFunctionName("max", MaxMetric.class);
|
||||||
|
|
||||||
|
String expr = "sort(by=\"nodeScore desc\",scoreNodes(facet(collection1, q=\"product_ss:product3\", buckets=\"product_ss\", bucketSorts=\"count(*) desc\", bucketSizeLimit=100, count(*))))";
|
||||||
|
|
||||||
|
stream = factory.constructStream(expr);
|
||||||
|
|
||||||
|
context = new StreamContext();
|
||||||
|
context.setSolrClientCache(cache);
|
||||||
|
|
||||||
|
stream.setStreamContext(context);
|
||||||
|
tuples = getTuples(stream);
|
||||||
|
|
||||||
|
//The highest scoring tuple will be the product searched for.
|
||||||
|
Tuple tuple = tuples.get(0);
|
||||||
|
assert(tuple.getString("node").equals("product3"));
|
||||||
|
assert(tuple.getLong("docFreq") == 3);
|
||||||
|
assert(tuple.getLong("count(*)") == 3);
|
||||||
|
|
||||||
|
Tuple tuple0 = tuples.get(1);
|
||||||
|
assert(tuple0.getString("node").equals("product4"));
|
||||||
|
assert(tuple0.getLong("docFreq") == 2);
|
||||||
|
assert(tuple0.getLong("count(*)") == 2);
|
||||||
|
|
||||||
|
Tuple tuple1 = tuples.get(2);
|
||||||
|
assert(tuple1.getString("node").equals("product1"));
|
||||||
|
assert(tuple1.getLong("docFreq") == 8);
|
||||||
|
assert(tuple1.getLong("count(*)") == 3);
|
||||||
|
|
||||||
|
Tuple tuple2 = tuples.get(3);
|
||||||
|
assert(tuple2.getString("node").equals("product5"));
|
||||||
|
assert(tuple2.getLong("docFreq") == 1);
|
||||||
|
assert(tuple2.getLong("count(*)") == 1);
|
||||||
|
|
||||||
|
|
||||||
|
cache.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGatherNodesFriendsStream() throws Exception {
|
public void testGatherNodesFriendsStream() throws Exception {
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user